# -*- coding: utf-8 -*-
"""
Audio Resampling
================

6
7
**Author**: `Caroline Chen <carolinechen@meta.com>`__, `Moto Hira <moto@meta.com>`__

8
This tutorial shows how to use torchaudio's resampling API.
9
10
11
12
13
14
15
16
17
18
19
"""

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

# Library versions used to run this tutorial, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

######################################################################
# Preparation
# -----------
#
# First, we import the modules and define the helper functions.
#
# .. note::
#    When running this tutorial in Google Colab, install the required packages
#    with the following.
#
#    .. code::
#
#       !pip install librosa

import math
import time

import librosa
import matplotlib.pyplot as plt
import pandas as pd
39
from IPython.display import Audio, display
# Show every row and column of the benchmark tables displayed at the end.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Offset added inside ``log(offset + f)`` when spacing the sweep frequencies,
# so the 0 Hz end of the sweep does not require evaluating ``log(0)``.
DEFAULT_OFFSET = 201


def _get_log_freq(sample_rate, max_sweep_rate, offset):
48
    """Get freqs evenly spaced out in log-scale, between [0, max_sweep_rate // 2]
49

50
51
52
53
    offset is used to avoid negative infinity `log(offset + x)`.

    """
    start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
54
    return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
55
56
57


def _get_inverse_log_freq(freq, sample_rate, offset):
58
59
60
61
    """Find the time where the given frequency is given by _get_log_freq"""
    half = sample_rate // 2
    return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))

62
63

def _get_freq_ticks(sample_rate, offset, f_max):
    """Compute x-axis tick positions for the log-scale frequencies of a sweep.

    Given the sample rate originally used to generate the sweep, find the time
    (in seconds) at which each major log-scale frequency ``v * 10**e``
    (``v`` in 1..9, ``e`` in 2..4) below ``sample_rate // 2`` occurs. The
    position of ``f_max`` is always appended last.

    Args:
        sample_rate (int): Sample rate used to generate the original sweep.
        offset (float): Offset used when generating the sweep.
        f_max (float): Extra frequency (Hz) whose position is appended.

    Returns:
        tuple[list, list]: Tick times in seconds and the matching frequencies in Hz.
    """
    nyquist = sample_rate // 2
    # Named ``tick_times`` rather than ``time`` to avoid shadowing the ``time`` module.
    tick_times, tick_freqs = [], []
    for exponent in range(2, 5):
        for v in range(1, 10):
            f = v * 10**exponent
            if f < nyquist:
                tick_times.append(_get_inverse_log_freq(f, sample_rate, offset) / sample_rate)
                tick_freqs.append(f)
    tick_times.append(_get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate)
    tick_freqs.append(f_max)
    return tick_times, tick_freqs


def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
    """Generate a one-second logarithmic sine sweep.

    The instantaneous frequency rises from 0 Hz to ``sample_rate // 2`` with
    log-scale spacing (see ``_get_log_freq``). The phase is the running sum of
    the per-sample phase increments.

    Args:
        sample_rate (int): Sample rate of the signal; also its length in samples.
        offset (float): Offset used for the log-scale frequency spacing.

    Returns:
        torch.Tensor: Tensor of shape ``(1, sample_rate)`` and dtype float64.
    """
    max_sweep_rate = sample_rate
    freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
    # Phase advance per sample for each instantaneous frequency.
    delta = 2 * math.pi * freq / sample_rate
    cumulative = torch.cumsum(delta, dim=0)
    signal = torch.sin(cumulative).unsqueeze(dim=0)
    return signal


def plot_sweep(
    waveform,
    sample_rate,
    title,
    max_sweep_rate=48000,
    offset=DEFAULT_OFFSET,
):
    """Plot the spectrogram of a (possibly resampled) sine sweep.

    The x-axis is labelled with the frequency of the original sweep (generated
    at ``max_sweep_rate``), and the y-axis with the frequency present in
    ``waveform``. The figure is shown with ``plt.show(block=True)``.

    Args:
        waveform (torch.Tensor): Signal to plot; only ``waveform[0]`` is used.
        sample_rate (int): Sample rate of ``waveform``.
        title (str): Title prefix; the sample rate is appended to it.
        max_sweep_rate (int): Sample rate the original sweep was generated with.
        offset (float): Offset that was used when generating the sweep.
    """
    x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
    y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]

    # Local names avoid shadowing the ``time`` module imported at file level.
    tick_times, tick_freqs = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
    # Keep a tick at every position but label only the selected frequencies;
    # a ``None`` label is drawn as empty text by matplotlib.
    freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in tick_freqs]
    freq_y = [f for f in tick_freqs if f in y_ticks and 1000 <= f <= sample_rate // 2]

    figure, axis = plt.subplots(1, 1)
    _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
    plt.xticks(tick_times, freq_x)
    plt.yticks(freq_y, freq_y)
    axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
    axis.set_ylabel("Waveform Frequency (Hz)")
    axis.xaxis.grid(True, alpha=0.67)
    axis.yaxis.grid(True, alpha=0.67)
    figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
    plt.colorbar(cax)
    plt.show(block=True)

######################################################################
# Resampling Overview
# -------------------
#
# To resample an audio waveform from one sample rate to another, you can use
# :py:class:`torchaudio.transforms.Resample` or
# :py:func:`torchaudio.functional.resample`.
# ``transforms.Resample`` precomputes and caches the kernel used for resampling,
# while ``functional.resample`` computes it on the fly, so using
# ``torchaudio.transforms.Resample`` will result in a speedup when resampling
# multiple waveforms using the same parameters (see the Performance
# Benchmarking section).
#
# Both resampling methods use `bandlimited sinc
# interpolation <https://ccrma.stanford.edu/~jos/resample/>`__ to compute
# signal values at arbitrary time steps. The implementation involves
# convolution, so we can take advantage of GPU / multithreading for
# performance improvements.
#
# .. note::
#
#    When using resampling in multiple subprocesses, such as data loading
#    with multiple worker processes, your application might create more
#    threads than your system can handle efficiently.
#    Setting ``torch.set_num_threads(1)`` might help in this case.
#
# Because a finite number of samples can only represent a finite number of
# frequencies, resampling does not produce perfect results, and a variety
# of parameters can be used to control its quality and computational
# speed. We demonstrate these properties by resampling a logarithmic
# sine sweep, which is a sine wave that increases exponentially in
# frequency over time.
#
# The spectrograms below show the frequency representation of the signal,
# where the x-axis corresponds to the frequency of the original
# waveform (in log scale), the y-axis to the frequency of the
# plotted waveform, and the color intensity to the amplitude.
#

sample_rate = 48000
waveform = get_sine_sweep(sample_rate)

plot_sweep(waveform, sample_rate, title="Original Waveform")
# Kept as the cell's final expression so notebook-style runners render the player.
Audio(waveform.numpy()[0], rate=sample_rate)

######################################################################
#
# Now we resample (downsample) it.
#
# We see that in the spectrogram of the resampled waveform, there is an
# artifact, which was not present in the original waveform.

resample_rate = 32000
# ``T.Resample`` builds its resampling kernel once, up front.
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform)

plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
# Kept as the cell's final expression so notebook-style runners render the player.
Audio(resampled_waveform.numpy()[0], rate=resample_rate)

######################################################################
# Controlling resampling quality with parameters
# ----------------------------------------------
#
# Lowpass filter width
# ~~~~~~~~~~~~~~~~~~~~
#
# Because the filter used for interpolation extends infinitely, the
# ``lowpass_filter_width`` parameter is used to control the width of
# the window applied to the interpolation filter. It is also referred to as
# the number of zero crossings, since the interpolation passes through
# zero at every time unit. Using a larger ``lowpass_filter_width``
# provides a sharper, more precise filter, but is more computationally
# expensive.
#

sample_rate = 48000
resample_rate = 32000

# Narrow filter window: cheaper to compute but a less sharp cutoff.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")

######################################################################
#

# Wide filter window: sharper cutoff at a higher computational cost.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")

######################################################################
# Rolloff
# ~~~~~~~
#
# The ``rolloff`` parameter is represented as a fraction of the Nyquist
# frequency, which is the maximal frequency representable by a given
# finite sample rate. ``rolloff`` determines the lowpass filter cutoff and
# controls the degree of aliasing, which takes place when frequencies
# higher than the Nyquist are mapped to lower frequencies. A lower rolloff
# will therefore reduce the amount of aliasing, but it will also reduce
# some of the higher frequencies.
#


sample_rate = 48000
resample_rate = 32000

# Cutoff at 99% of the Nyquist frequency.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")

######################################################################
#

# Lower cutoff (80% of Nyquist): less aliasing, but high frequencies are attenuated.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")


######################################################################
# Window function
# ~~~~~~~~~~~~~~~
#
# By default, ``torchaudio``’s resample uses the Hann window filter, which is
# a weighted cosine function. It additionally supports the Kaiser window,
# which is a near optimal window function that contains an additional
# ``beta`` parameter that allows for the design of the smoothness of the
# filter and width of impulse. This can be controlled using the
# ``resampling_method`` parameter.
#


sample_rate = 48000
resample_rate = 32000

# "sinc_interpolation" selects the default Hann-windowed sinc filter.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")

######################################################################
#

# Kaiser window with the default ``beta``.
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")


######################################################################
# Comparison against librosa
# --------------------------
#
# ``torchaudio``’s resample function can be used to produce results similar to
# those of librosa (resampy)’s Kaiser window resampling, with some noise.
#

# Rates shared by the torchaudio / librosa comparisons below.
sample_rate = 48000
resample_rate = 32000

######################################################################
# kaiser_best
# ~~~~~~~~~~~
#
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    # torchaudio settings that mirror librosa's ``kaiser_best`` preset
    lowpass_filter_width=64,
    rolloff=0.9475937167399596,
    resampling_method="kaiser_window",
    beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")

######################################################################
#

# Resample the squeezed sweep with librosa, then restore the leading channel dimension.
librosa_resampled_waveform = torch.from_numpy(
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")

######################################################################
#

# Mean squared error between the torchaudio and librosa ``kaiser_best`` outputs.
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)

######################################################################
# kaiser_fast
# ~~~~~~~~~~~
#
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    # torchaudio settings that mirror librosa's ``kaiser_fast`` preset
    lowpass_filter_width=16,
    rolloff=0.85,
    resampling_method="kaiser_window",
    beta=8.555504641634386,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

######################################################################
#
librosa_resampled_waveform = torch.from_numpy(
    # Resample the squeezed sweep with librosa, then restore the channel dimension.
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")

######################################################################
#

# Mean squared error between the torchaudio and librosa ``kaiser_fast`` outputs.
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)

######################################################################
# Performance Benchmarking
# ------------------------
#
# Below are benchmarks for downsampling and upsampling waveforms between
# two pairs of sampling rates. We demonstrate the performance implications
# that the ``lowpass_filter_width``, window type, and sample rates can
# have. Additionally, we provide a comparison against ``librosa``\ ’s
# ``kaiser_best`` and ``kaiser_fast`` using their corresponding parameters
# in ``torchaudio``.
#
# To elaborate on the results:
#
# - a larger ``lowpass_filter_width`` results in a larger resampling kernel,
#   and therefore increases computation time for both the kernel computation
#   and convolution
# - using ``kaiser_window`` results in longer computation times than the default
#   ``sinc_interpolation`` because it is more complex to compute the intermediate
#   window values
# - a large GCD between the sample and resample rate will result
#   in a simplification that allows for a smaller kernel and faster kernel computation.
#


def _average_seconds(fn, iters):
    """Call ``fn()`` ``iters`` times and return the mean wall-clock seconds per call."""
    # perf_counter is the highest-resolution clock available for short intervals.
    begin = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - begin) / iters


def benchmark_resample(
    method,
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interpolation",
    beta=None,
    librosa_type=None,
    iters=5,
):
    """Return the average time, in seconds, of one resampling call.

    Args:
        method (str): ``"functional"`` (``F.resample``), ``"transforms"``
            (``T.Resample``; building the resampler is not timed) or
            ``"librosa"`` (``librosa.resample``).
        waveform (torch.Tensor): Signal to resample. For ``"librosa"`` it is
            squeezed and converted to NumPy before timing starts.
        sample_rate (int): Original sample rate.
        resample_rate (int): Target sample rate.
        lowpass_filter_width (int): torchaudio filter width (zero crossings).
        rolloff (float): torchaudio cutoff as a fraction of the Nyquist frequency.
        resampling_method (str): torchaudio window method.
        beta (float or None): Kaiser window shape parameter; ``None`` lets
            torchaudio use its default.
        librosa_type (str): ``res_type`` passed to ``librosa.resample``.
        iters (int): Number of timed calls to average over.

    Raises:
        ValueError: If ``method`` is not one of the values above, or ``iters`` < 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be at least 1, got {iters}")

    if method == "functional":
        return _average_seconds(
            lambda: F.resample(
                waveform,
                sample_rate,
                resample_rate,
                lowpass_filter_width=lowpass_filter_width,
                rolloff=rolloff,
                resampling_method=resampling_method,
                beta=beta,
            ),
            iters,
        )
    if method == "transforms":
        # Build the resampler (which precomputes its kernel) outside the timed
        # region, so the measurement reflects reuse of the cached kernel.
        resampler = T.Resample(
            sample_rate,
            resample_rate,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_method=resampling_method,
            beta=beta,
            dtype=waveform.dtype,
        )
        return _average_seconds(lambda: resampler(waveform), iters)
    if method == "librosa":
        waveform_np = waveform.squeeze().numpy()
        return _average_seconds(
            lambda: librosa.resample(waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type),
            iters,
        )
    raise ValueError(f"Unknown method: {method!r}")


######################################################################
#

configs = {
    "downsample (48 -> 44.1 kHz)": [48000, 44100],
    "downsample (16 -> 8 kHz)": [16000, 8000],
    "upsample (44.1 -> 48 kHz)": [44100, 48000],
    "upsample (8 -> 16 kHz)": [8000, 16000],
}

# torchaudio settings that mirror librosa's ``kaiser_best`` and ``kaiser_fast``
# presets (the same values as in the "Comparison against librosa" section).
kaiser_best_params = {
    "lowpass_filter_width": 64,
    "rolloff": 0.9475937167399596,
    "resampling_method": "kaiser_window",
    "beta": 14.769656459379492,
}
kaiser_fast_params = {
    "lowpass_filter_width": 16,
    "rolloff": 0.85,
    "resampling_method": "kaiser_window",
    "beta": 8.555504641634386,
}

# The library versions are the same for every table, so report them once.
print(f"torchaudio: {torchaudio.__version__}")
print(f"librosa: {librosa.__version__}")

for label, (sample_rate, resample_rate) in configs.items():
    times, rows = [], []
    waveform = get_sine_sweep(sample_rate)

    # sinc, 64 zero-crossings (not benchmarked with librosa, hence ``None``)
    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
    times.append([None, 1000 * f_time, 1000 * t_time])
    rows.append("sinc (width 64)")

    # sinc, 16 zero-crossings
    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
    times.append([None, 1000 * f_time, 1000 * t_time])
    rows.append("sinc (width 16)")

    # kaiser best
    lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best")
    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, **kaiser_best_params)
    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, **kaiser_best_params)
    times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
    rows.append("kaiser_best")

    # kaiser fast
    lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast")
    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, **kaiser_fast_params)
    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, **kaiser_fast_params)
    times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
    rows.append("kaiser_fast")

    # Times are converted to milliseconds above; group the columns under the config label.
    df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
    df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
    display(df.round(2))