# -*- coding: utf-8 -*-
"""
Audio Resampling
================

This tutorial shows how to use torchaudio's resampling API.
"""

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

# Report the library versions this tutorial was run with, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

######################################################################
# Preparation
# -----------
#
# First, we import the modules and define the helper functions.
#
# .. note::
#    When running this tutorial in Google Colab, install the required packages
#    with the following.
#
#    .. code::
#
#       !pip install librosa

import math
import time

import librosa
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import Audio, display

# Show the complete benchmark tables instead of truncated ones.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Offset added to frequencies before taking their log in the sweep helpers,
# so that ``log(offset + f)`` stays finite at ``f = 0``.
DEFAULT_OFFSET = 201


def _get_log_freq(sample_rate, max_sweep_rate, offset):
46
    """Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
47

48
49
50
51
    offset is used to avoid negative infinity `log(offset + x)`.

    """
    start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
52
    return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
53
54
55


def _get_inverse_log_freq(freq, sample_rate, offset):
56
57
58
59
    """Find the time where the given frequency is given by _get_log_freq"""
    half = sample_rate // 2
    return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))

60
61

def _get_freq_ticks(sample_rate, offset, f_max):
    """Compute x-axis tick positions for the spectrogram of a log sine sweep.

    Given the original sample rate used for generating the sweep, find the
    x-axis values (in seconds) where the log-scale major frequency values
    (100, 200, ..., 90000 Hz, restricted to those below ``sample_rate // 2``)
    fall in. The position of ``f_max`` is appended last.

    Returns:
        ``(times, freqs)``: two parallel lists of tick positions and frequencies.
    """
    # Named ``times``/``freqs`` rather than ``time``/``freq`` so the locals do not
    # shadow the module-level ``time`` import.
    times, freqs = [], []
    half = sample_rate // 2
    for exponent in range(2, 5):
        for mantissa in range(1, 10):
            f = mantissa * 10**exponent
            if f < half:
                times.append(_get_inverse_log_freq(f, sample_rate, offset) / sample_rate)
                freqs.append(f)
    times.append(_get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate)
    freqs.append(f_max)
    return times, freqs


def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
    """Generate a one-second logarithmic sine sweep.

    The instantaneous frequency rises from 0 Hz to ``sample_rate // 2`` on a
    log scale (see ``_get_log_freq``). The phase is the running sum of the
    per-sample phase increments.

    Args:
        sample_rate: Sample rate of the generated signal (Hz); also its length in samples.
        offset: Log offset forwarded to ``_get_log_freq``.

    Returns:
        A ``(1, sample_rate)`` ``torch.double`` tensor.
    """
    max_sweep_rate = sample_rate
    freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
    # Per-sample phase increment 2*pi*f/fs; integrating it gives the phase.
    delta = 2 * math.pi * freq / sample_rate
    cumulative = torch.cumsum(delta, dim=0)
    signal = torch.sin(cumulative).unsqueeze(dim=0)
    return signal


def plot_sweep(
    waveform,
    sample_rate,
    title,
    max_sweep_rate=48000,
    offset=DEFAULT_OFFSET,
):
    """Plot the spectrogram of a (possibly resampled) log sine sweep.

    The x-axis is labelled with the frequency of the original sweep (log
    scale). The y-axis shows the frequency present in ``waveform``.

    Args:
        waveform: Tensor whose first row (``waveform[0]``) is plotted.
        sample_rate: Sample rate of ``waveform`` (Hz).
        title: Figure title; the sample rate is appended to it.
        max_sweep_rate: Sample rate that was used to generate the original sweep.
        offset: Log offset that was used to generate the original sweep.
    """
    x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
    y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]

    tick_times, tick_freqs = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
    # Every tick keeps its grid line. Ticks whose frequency is not in ``x_ticks``
    # are given ``None`` as their label.
    freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in tick_freqs]
    freq_y = [f for f in tick_freqs if f in y_ticks and 1000 <= f <= sample_rate // 2]

    figure, axis = plt.subplots(1, 1)
    _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
    plt.xticks(tick_times, freq_x)
    plt.yticks(freq_y, freq_y)
    axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
    axis.set_ylabel("Waveform Frequency (Hz)")
    axis.xaxis.grid(True, alpha=0.67)
    axis.yaxis.grid(True, alpha=0.67)
    figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
    plt.colorbar(cax)
    plt.show(block=True)


######################################################################
# Resampling Overview
# -------------------
#
# To resample an audio waveform from one sample rate to another, you can use
# :py:class:`torchaudio.transforms.Resample` or
# :py:func:`torchaudio.functional.resample`.
# ``transforms.Resample`` precomputes and caches the kernel used for resampling,
# while ``functional.resample`` computes it on the fly. Using
# ``torchaudio.transforms.Resample`` therefore results in a speedup when
# resampling multiple waveforms with the same parameters (see the Performance
# Benchmarking section).
#
# Both resampling methods use `bandlimited sinc
# interpolation <https://ccrma.stanford.edu/~jos/resample/>`__ to compute
# signal values at arbitrary time steps. The implementation involves
# convolution, so we can take advantage of GPU / multithreading for
# performance improvements.
#
# .. note::
#
#    When using resampling in multiple subprocesses, such as data loading
#    with multiple worker processes, your application might create more
#    threads than your system can handle efficiently.
#    Setting ``torch.set_num_threads(1)`` might help in this case.
#
# Because a finite number of samples can only represent a finite number of
# frequencies, resampling does not produce perfect results. A variety of
# parameters can be used to control its quality and computational
# speed. We demonstrate these properties by resampling a logarithmic
# sine sweep, which is a sine wave that increases exponentially in
# frequency over time.
#
# The spectrograms below show the frequency representation of the signal.
# The x-axis corresponds to the frequency of the original waveform (in log
# scale), the y-axis to the frequency of the plotted waveform, and the color
# intensity to the amplitude.
#

# Generate and plot a one-second log sine sweep sampled at 48 kHz.
sample_rate = 48000
waveform = get_sine_sweep(sample_rate)

plot_sweep(waveform, sample_rate, title="Original Waveform")
Audio(waveform.numpy()[0], rate=sample_rate)

######################################################################
#
# Now we resample (downsample) it.
#
# We see that in the spectrogram of the resampled waveform, there is an
# artifact, which was not present in the original waveform.

resample_rate = 32000

# ``T.Resample`` computes the resampling kernel once, so ``resampler`` can be
# reused for further waveforms with the same parameters.
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform)

plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)

######################################################################
# Controlling resampling quality with parameters
# ----------------------------------------------
#
# Lowpass filter width
# ~~~~~~~~~~~~~~~~~~~~
#
# Because the filter used for interpolation extends infinitely, the
# ``lowpass_filter_width`` parameter is used to control the width of
# the filter used to window the interpolation. It is also referred to as
# the number of zero crossings, since the interpolation passes through
# zero at every time unit. Using a larger ``lowpass_filter_width``
# provides a sharper, more precise filter, but is more computationally
# expensive.
#

sample_rate = 48000
resample_rate = 32000

# A narrow filter: cheaper to compute, but less precise.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")

######################################################################
#

197
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")

######################################################################
# Rolloff
# ~~~~~~~
#
# The ``rolloff`` parameter is represented as a fraction of the Nyquist
# frequency, which is the maximal frequency representable by a given
# finite sample rate. ``rolloff`` determines the lowpass filter cutoff and
# controls the degree of aliasing, which takes place when frequencies
# higher than the Nyquist are mapped to lower frequencies. A lower rolloff
# will therefore reduce the amount of aliasing, but it will also reduce
# some of the higher frequencies.
#


sample_rate = 48000
resample_rate = 32000

# ``rolloff`` is expressed as a fraction of the Nyquist frequency.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")

######################################################################
#

# A lower rolloff reduces aliasing but also removes more of the high frequencies.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")


######################################################################
# Window function
# ~~~~~~~~~~~~~~~
#
# By default, ``torchaudio``’s resample uses the Hann window filter, which is
# a weighted cosine function. It additionally supports the Kaiser window,
# a near-optimal window function with an additional ``beta`` parameter that
# controls the smoothness of the filter and the width of its impulse.
# The window is selected with the ``resampling_method`` parameter.
#


sample_rate = 48000
resample_rate = 32000

# ``sinc_interpolation`` uses the (default) Hann window.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")

######################################################################
#

# Kaiser window with its default ``beta`` (none is passed here).
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")


######################################################################
# Comparison against librosa
# --------------------------
#
# ``torchaudio``’s resample function can be used to produce results similar to
# those of librosa (resampy)’s Kaiser window resampling, with some noise.
#

# Sample rates used for the librosa comparison below.
sample_rate = 48000
resample_rate = 32000

######################################################################
# kaiser_best
# ~~~~~~~~~~~
#

# torchaudio parameters corresponding to librosa's ``kaiser_best``.
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=64,
    rolloff=0.9475937167399596,
    resampling_method="kaiser_window",
    beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")

######################################################################
#

282
librosa_resampled_waveform = torch.from_numpy(
hwangjeff's avatar
hwangjeff committed
283
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
284
).unsqueeze(0)
285
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
286

287
288
289
######################################################################
#

# Mean squared difference between the torchaudio and librosa results.
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)

293
######################################################################
294
# kaiser_fast
295
296
# ~~~~~~~~~~~
#
297
298
299
300
301
302
303
# torchaudio parameters corresponding to librosa's ``kaiser_fast``.
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=16,
    rolloff=0.85,
    resampling_method="kaiser_window",
    beta=8.555504641634386,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

######################################################################
#

# Drop the leading channel dimension for librosa, then add it back so the
# result can be compared with the torchaudio output.
librosa_resampled_waveform = torch.from_numpy(
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")

######################################################################
#

# Mean squared difference between the torchaudio and librosa results.
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)

######################################################################
# Performance Benchmarking
# ------------------------
#
# Below are benchmarks for downsampling and upsampling waveforms between
# two pairs of sampling rates. We demonstrate the performance implications
# that the ``lowpass_filter_width``, window type, and sample rates can
# have. Additionally, we provide a comparison against ``librosa``\ ’s
# ``kaiser_best`` and ``kaiser_fast`` using their corresponding parameters
# in ``torchaudio``.
#
# To elaborate on the results:
#
# - a larger ``lowpass_filter_width`` results in a larger resampling kernel,
#   and therefore increases computation time for both the kernel computation
#   and convolution
# - using ``kaiser_window`` results in longer computation times than the default
#   ``sinc_interpolation`` because it is more complex to compute the intermediate
#   window values
# - a large GCD between the sample and resample rate will result
#   in a simplification that allows for a smaller kernel and faster kernel computation.
#


345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def benchmark_resample(
    method,
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interpolation",
    beta=None,
    librosa_type=None,
    iters=5,
):
    if method == "functional":
        begin = time.monotonic()
        for _ in range(iters):
            F.resample(
                waveform,
                sample_rate,
                resample_rate,
                lowpass_filter_width=lowpass_filter_width,
                rolloff=rolloff,
                resampling_method=resampling_method,
            )
        elapsed = time.monotonic() - begin
        return elapsed / iters
    elif method == "transforms":
        resampler = T.Resample(
            sample_rate,
            resample_rate,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_method=resampling_method,
            dtype=waveform.dtype,
        )
        begin = time.monotonic()
        for _ in range(iters):
            resampler(waveform)
        elapsed = time.monotonic() - begin
        return elapsed / iters
    elif method == "librosa":
        waveform_np = waveform.squeeze().numpy()
        begin = time.monotonic()
        for _ in range(iters):
            librosa.resample(waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type)
        elapsed = time.monotonic() - begin
        return elapsed / iters


######################################################################
#

configs = {
    "downsample (48 -> 44.1 kHz)": [48000, 44100],
    "downsample (16 -> 8 kHz)": [16000, 8000],
    "upsample (44.1 -> 48 kHz)": [44100, 48000],
    "upsample (8 -> 16 kHz)": [8000, 16000],
}

# Each case is a tuple of:
#   - the row label,
#   - the librosa ``res_type`` to compare against (``None`` when there is none),
#   - the torchaudio resampling parameters.
# The kaiser parameters match the ones used in the librosa comparison above.
benchmark_cases = [
    ("sinc (width 64)", None, {"lowpass_filter_width": 64}),
    ("sinc (width 16)", None, {"lowpass_filter_width": 16}),
    (
        "kaiser_best",
        "kaiser_best",
        {
            "lowpass_filter_width": 64,
            "rolloff": 0.9475937167399596,
            "resampling_method": "kaiser_window",
            "beta": 14.769656459379492,
        },
    ),
    (
        "kaiser_fast",
        "kaiser_fast",
        {
            "lowpass_filter_width": 16,
            "rolloff": 0.85,
            "resampling_method": "kaiser_window",
            "beta": 8.555504641634386,
        },
    ),
]

# The versions are the same for every table, so print them only once.
print(f"torchaudio: {torchaudio.__version__}")
print(f"librosa: {librosa.__version__}")

for label, (sample_rate, resample_rate) in configs.items():
    waveform = get_sine_sweep(sample_rate)
    times, rows = [], []
    for row_label, librosa_type, params in benchmark_cases:
        # The plain sinc rows have no librosa counterpart; leave that cell empty.
        lib_ms = None
        if librosa_type is not None:
            lib_ms = 1000 * benchmark_resample(
                "librosa", waveform, sample_rate, resample_rate, librosa_type=librosa_type
            )
        f_ms = 1000 * benchmark_resample("functional", waveform, sample_rate, resample_rate, **params)
        t_ms = 1000 * benchmark_resample("transforms", waveform, sample_rate, resample_rate, **params)
        times.append([lib_ms, f_ms, t_ms])
        rows.append(row_label)

    df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
    df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
    display(df.round(2))