# -*- coding: utf-8 -*-
"""
Audio Resampling
================

**Author**: `Caroline Chen <carolinechen@meta.com>`__, `Moto Hira <moto@meta.com>`__

This tutorial shows how to use torchaudio's resampling API.
"""

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

# Report the library versions this tutorial was rendered with, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

######################################################################
# Preparation
# -----------
#
# First, we import the modules and define the helper functions.
#

import math
import timeit

import librosa
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import pandas as pd
import resampy
from IPython.display import Audio

# Show the benchmark tables in full instead of truncating rows/columns.
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

# Offset added inside ``log(offset + f)`` when building the log-scale sweep,
# so that the 0 Hz end of the range never evaluates ``log(0)``.
DEFAULT_OFFSET = 201


def _get_log_freq(sample_rate, max_sweep_rate, offset):
    """Get frequencies evenly spaced in log-scale over [0, max_sweep_rate // 2].

    The points are spaced evenly in ``log(offset + f)``; ``offset`` avoids
    evaluating ``log(0)`` (negative infinity) at the 0 Hz end of the range.

    Args:
        sample_rate (int): Number of frequency points to generate.
        max_sweep_rate (int): Rate whose Nyquist frequency (``max_sweep_rate // 2``)
            is the last frequency produced.
        offset (float): Positive offset applied before taking the log.

    Returns:
        torch.Tensor: 1-D ``torch.double`` tensor of length ``sample_rate``,
        increasing from 0 to ``max_sweep_rate // 2``.
    """
    start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
    return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset


def _get_inverse_log_freq(freq, sample_rate, offset):
    """Find the position along the sweep at which ``_get_log_freq`` reaches ``freq``.

    This inverts the log-scale spacing used by
    ``_get_log_freq(sample_rate, sample_rate, offset)``.

    Args:
        freq (float): Frequency to locate, in Hz.
        sample_rate (int): Sample rate used to generate the sweep.
        offset (float): Offset used when generating the sweep.

    Returns:
        float: Position in samples. It is 0 for ``freq == 0`` and ``sample_rate``
        for ``freq == sample_rate // 2``. Callers divide by ``sample_rate``
        to get seconds.
    """
    half = sample_rate // 2
    return sample_rate * (math.log(1 + freq / offset) / math.log(1 + half / offset))

def _get_freq_ticks(sample_rate, offset, f_max):
    """Compute x-axis tick positions for the log-scale major frequencies of a sweep.

    Given the original sample rate used to generate the sweep, find the time
    (in seconds) at which each major log-scale frequency occurs. The major
    frequencies are 100, 200, ..., 90000 Hz, keeping only those below
    ``sample_rate // 2``. The time at which ``f_max`` occurs is appended last.

    Args:
        sample_rate (int): Sample rate used to generate the sweep.
        offset (float): Offset used when generating the sweep.
        f_max (float): Extra frequency whose position is appended at the end.

    Returns:
        tuple[list[float], list]: Tick times and their frequencies, in matching order.
    """
    times, freq = [], []
    for exp in range(2, 5):
        for v in range(1, 10):
            f = v * 10**exp
            if f < sample_rate // 2:
                times.append(_get_inverse_log_freq(f, sample_rate, offset) / sample_rate)
                freq.append(f)
    times.append(_get_inverse_log_freq(f_max, sample_rate, offset) / sample_rate)
    freq.append(f_max)
    return times, freq

def get_sine_sweep(sample_rate, offset=DEFAULT_OFFSET):
    """Generate a logarithmic sine sweep from 0 Hz up to ``sample_rate // 2``.

    Args:
        sample_rate (int): Sample rate of the generated waveform. The sweep has
            ``sample_rate`` samples.
        offset (float): Offset passed to ``_get_log_freq`` to keep the log finite.

    Returns:
        torch.Tensor: ``torch.double`` waveform of shape ``(1, sample_rate)``.
    """
    max_sweep_rate = sample_rate
    freq = _get_log_freq(sample_rate, max_sweep_rate, offset)
    # Integrate the instantaneous frequency (per-sample phase increment) to get the phase.
    delta = 2 * math.pi * freq / sample_rate
    cumulative = torch.cumsum(delta, dim=0)
    signal = torch.sin(cumulative).unsqueeze(dim=0)
    return signal


def plot_sweep(
    waveform,
    sample_rate,
    title,
    max_sweep_rate=48000,
    offset=DEFAULT_OFFSET,
):
    """Plot the spectrogram of a sine sweep with log-scale frequency ticks.

    Args:
        waveform (torch.Tensor): Waveform whose first row (``waveform[0]``) is plotted.
        sample_rate (int): Sample rate of ``waveform``.
        title (str): Figure title; the sample rate is appended to it.
        max_sweep_rate (int): Sample rate the sweep was originally generated at.
            It determines where the x-axis ticks fall.
        offset (float): Offset used when generating the sweep.
    """
    x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
    y_ticks = [1000, 5000, 10000, 20000, sample_rate // 2]

    time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
    # Label only the selected x ticks; the other grid positions get a None label
    # (matplotlib draws a None label as empty text).
    freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
    freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2]

    figure, axis = plt.subplots(1, 1)
    _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
    plt.xticks(time, freq_x)
    plt.yticks(freq_y, freq_y)
    axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
    axis.set_ylabel("Waveform Frequency (Hz)")
    axis.xaxis.grid(True, alpha=0.67)
    axis.yaxis.grid(True, alpha=0.67)
    figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
    plt.colorbar(cax)

######################################################################
# Resampling Overview
# -------------------
#
# To resample an audio waveform from one sample rate to another, you can use
# :py:class:`torchaudio.transforms.Resample` or
# :py:func:`torchaudio.functional.resample`.
# ``transforms.Resample`` precomputes and caches the kernel used for resampling,
# while ``functional.resample`` computes it on the fly, so using
# ``torchaudio.transforms.Resample`` will result in a speedup when resampling
# multiple waveforms using the same parameters (see the Performance
# Benchmarking section).
#
# Both resampling methods use `bandlimited sinc
# interpolation <https://ccrma.stanford.edu/~jos/resample/>`__ to compute
# signal values at arbitrary time steps. The implementation involves
# convolution, so we can take advantage of GPU / multithreading for
# performance improvements.
#
# .. note::
#
#    When using resampling in multiple subprocesses, such as data loading
#    with multiple worker processes, your application might create more
#    threads than your system can handle efficiently.
#    Setting ``torch.set_num_threads(1)`` might help in this case.
#
# Because a finite number of samples can only represent a finite number of
# frequencies, resampling does not produce perfect results, and a variety
# of parameters can be used to control its quality and computational
# speed. We demonstrate these properties through resampling a logarithmic
# sine sweep, which is a sine wave that increases exponentially in
# frequency over time.
#
# The spectrograms below show the frequency representation of the signal,
# where the x-axis corresponds to the frequency of the original
# waveform (in log scale), the y-axis to the frequency of the
# plotted waveform, and color intensity to the amplitude.
#

# Generate the 48 kHz sine sweep used throughout this tutorial, then plot and play it.
sample_rate = 48000
waveform = get_sine_sweep(sample_rate)

plot_sweep(waveform, sample_rate, title="Original Waveform")
Audio(waveform.numpy()[0], rate=sample_rate)
######################################################################
#
# Now we resample (downsample) it.
#
# In the spectrogram of the resampled waveform there is an artifact
# that was not present in the original waveform.
# This effect is called aliasing.
# `This page <https://music.arts.uci.edu/dobrian/digitalaudio.htm>`__ has
# an explanation of how it happens, and why it looks like a reflection.

resample_rate = 32000
# ``dtype`` makes the precomputed kernel match the dtype of the input sweep.
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform)

plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)

######################################################################
# Controlling resampling quality with parameters
# ----------------------------------------------
#
# Lowpass filter width
# ~~~~~~~~~~~~~~~~~~~~
#
# Because the filter used for interpolation extends infinitely, the
# ``lowpass_filter_width`` parameter is used to control the width of
# the filter used to window the interpolation. It is also referred to as
# the number of zero crossings, since the interpolation passes through
# zero at every time unit. Using a larger ``lowpass_filter_width``
# provides a sharper, more precise filter, but is more computationally
# expensive.
#

sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")

######################################################################
# Rolloff
# ~~~~~~~
#
# The ``rolloff`` parameter is represented as a fraction of the Nyquist
# frequency, which is the maximal frequency representable by a given
# finite sample rate. ``rolloff`` determines the lowpass filter cutoff and
# controls the degree of aliasing, which takes place when frequencies
# higher than the Nyquist are mapped to lower frequencies. A lower rolloff
# will therefore reduce the amount of aliasing, but it will also reduce
# some of the higher frequencies.
#


sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")
######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")


######################################################################
# Window function
# ~~~~~~~~~~~~~~~
#
# By default, ``torchaudio``’s resample uses the Hann window filter, which is
# a weighted cosine function. It additionally supports the Kaiser window,
# which is a near optimal window function that contains an additional
# ``beta`` parameter that allows for the design of the smoothness of the
# filter and width of impulse. This can be controlled using the
# ``resampling_method`` parameter.
#


sample_rate = 48000
resample_rate = 32000

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_hann")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
######################################################################
#

resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interp_kaiser")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")


######################################################################
# Comparison against librosa
# --------------------------
#
# ``torchaudio``’s resample function can be used to produce results similar to
# those of librosa (resampy)’s kaiser window resampling, with some noise.
#

sample_rate = 48000
resample_rate = 32000
######################################################################
# kaiser_best
# ~~~~~~~~~~~
#
# torchaudio parameters corresponding to librosa's ``kaiser_best`` preset.
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=64,
    rolloff=0.9475937167399596,
    resampling_method="sinc_interp_kaiser",
    beta=14.769656459379492,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
######################################################################
#

# librosa operates on a 1-D NumPy array, so drop the channel dimension and restore it afterwards.
librosa_resampled_waveform = torch.from_numpy(
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
######################################################################
#

mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)

######################################################################
# kaiser_fast
# ~~~~~~~~~~~
#
# torchaudio parameters corresponding to librosa's ``kaiser_fast`` preset.
resampled_waveform = F.resample(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=16,
    rolloff=0.85,
    resampling_method="sinc_interp_kaiser",
    beta=8.555504641634386,
)
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

######################################################################
#
# librosa operates on a 1-D NumPy array, so drop the channel dimension and restore it afterwards.
librosa_resampled_waveform = torch.from_numpy(
    librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")
######################################################################
#

mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)

######################################################################
# Performance Benchmarking
# ------------------------
#
# Below are benchmarks for downsampling and upsampling waveforms between
# two pairs of sampling rates. We demonstrate how the
# ``lowpass_filter_width``, window type, and sample rates affect
# performance. Additionally, we provide a comparison against ``librosa``\ ’s
# ``kaiser_best`` and ``kaiser_fast`` using their corresponding parameters
# in ``torchaudio``.
#

# Record the versions of every library being benchmarked, one per line.
print(
    f"torchaudio: {torchaudio.__version__}",
    f"librosa: {librosa.__version__}",
    f"resampy: {resampy.__version__}",
    sep="\n",
)

######################################################################
#

def benchmark_resample_functional(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interp_hann",
    beta=None,
    iters=5,
):
    """Time :py:func:`torchaudio.functional.resample`.

    ``functional.resample`` computes its kernel on the fly, so each timed call
    includes building the kernel as well as applying it.

    Args:
        waveform (torch.Tensor): Input waveform.
        sample_rate (int): Original sample rate.
        resample_rate (int): Target sample rate.
        lowpass_filter_width (int): Passed through to ``resample``.
        rolloff (float): Passed through to ``resample``.
        resampling_method (str): Passed through to ``resample``.
        beta (float or None): Kaiser window ``beta``, passed through to ``resample``.
        iters (int): Number of timed calls.

    Returns:
        float: Average time per call, in milliseconds.
    """

    def _run():
        torchaudio.functional.resample(
            waveform,
            sample_rate,
            resample_rate,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_method=resampling_method,
            beta=beta,
        )

    # ``timeit`` returns the total seconds for ``iters`` calls; convert to ms per call.
    return timeit.timeit(_run, number=iters) * 1000 / iters


######################################################################
#

def benchmark_resample_transforms(
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interp_hann",
    beta=None,
    iters=5,
):
    """Time :py:class:`torchaudio.transforms.Resample`.

    The ``Resample`` module (which precomputes its kernel) is built once
    before timing starts, so only applying it to ``waveform`` is measured.

    Args:
        waveform (torch.Tensor): Input waveform; its dtype and device are used
            for the resampler.
        sample_rate (int): Original sample rate.
        resample_rate (int): Target sample rate.
        lowpass_filter_width (int): Passed through to ``Resample``.
        rolloff (float): Passed through to ``Resample``.
        resampling_method (str): Passed through to ``Resample``.
        beta (float or None): Kaiser window ``beta``, passed through to ``Resample``.
        iters (int): Number of timed calls.

    Returns:
        float: Average time per call, in milliseconds.
    """
    # Construction happens outside the timed callable, as the original setup string did.
    resampler = torchaudio.transforms.Resample(
        sample_rate,
        resample_rate,
        lowpass_filter_width=lowpass_filter_width,
        rolloff=rolloff,
        resampling_method=resampling_method,
        dtype=waveform.dtype,
        beta=beta,
    )
    resampler.to(waveform.device)

    # ``timeit`` returns the total seconds for ``iters`` calls; convert to ms per call.
    return timeit.timeit(lambda: resampler(waveform), number=iters) * 1000 / iters


######################################################################
#

def benchmark_resample_librosa(
    waveform,
    sample_rate,
    resample_rate,
    res_type=None,
    iters=5,
):
    """Time ``librosa.resample``.

    Args:
        waveform (torch.Tensor): Input waveform; it is squeezed and converted to
            NumPy once, before timing.
        sample_rate (int): Original sample rate.
        resample_rate (int): Target sample rate.
        res_type (str or None): Passed through to ``librosa.resample``.
        iters (int): Number of timed calls.

    Returns:
        float: Average time per call, in milliseconds.
    """
    # librosa works on NumPy arrays; convert once so the conversion is not timed.
    waveform_np = waveform.squeeze().numpy()

    def _run():
        librosa.resample(
            waveform_np,
            orig_sr=sample_rate,
            target_sr=resample_rate,
            res_type=res_type,
        )

    # ``timeit`` returns the total seconds for ``iters`` calls; convert to ms per call.
    return timeit.timeit(_run, number=iters) * 1000 / iters


######################################################################
#

def benchmark(sample_rate, resample_rate):
    """Benchmark resampling a sine sweep from ``sample_rate`` to ``resample_rate``.

    A float32 sweep is resampled with four configurations: sinc width 64,
    sinc width 16, ``kaiser_best`` and ``kaiser_fast``. Each configuration is
    timed with ``torchaudio.functional.resample`` and
    ``torchaudio.transforms.Resample``. The two Kaiser presets are also timed
    with ``librosa.resample``.

    Args:
        sample_rate (int): Sample rate of the generated test sweep.
        resample_rate (int): Target sample rate.

    Returns:
        pandas.DataFrame: Average time per call in milliseconds. There is one
        row per configuration and the columns are ``librosa``, ``functional``
        and ``transforms``. The plain sinc rows have no librosa counterpart,
        so their ``librosa`` entry is missing.
    """
    times, rows = [], []
    waveform = get_sine_sweep(sample_rate).to(torch.float32)

    args = (waveform, sample_rate, resample_rate)

    # sinc 64 zero-crossings
    f_time = benchmark_resample_functional(*args, lowpass_filter_width=64)
    t_time = benchmark_resample_transforms(*args, lowpass_filter_width=64)
    times.append([None, f_time, t_time])
    rows.append("sinc (width 64)")

    # sinc 16 zero-crossings
    f_time = benchmark_resample_functional(*args, lowpass_filter_width=16)
    t_time = benchmark_resample_transforms(*args, lowpass_filter_width=16)
    times.append([None, f_time, t_time])
    rows.append("sinc (width 16)")

    # kaiser best: torchaudio parameters corresponding to librosa's "kaiser_best"
    kwargs = {
        "lowpass_filter_width": 64,
        "rolloff": 0.9475937167399596,
        "resampling_method": "sinc_interp_kaiser",
        "beta": 14.769656459379492,
    }
    lib_time = benchmark_resample_librosa(*args, res_type="kaiser_best")
    f_time = benchmark_resample_functional(*args, **kwargs)
    t_time = benchmark_resample_transforms(*args, **kwargs)
    times.append([lib_time, f_time, t_time])
    rows.append("kaiser_best")

    # kaiser fast: torchaudio parameters corresponding to librosa's "kaiser_fast"
    kwargs = {
        "lowpass_filter_width": 16,
        "rolloff": 0.85,
        "resampling_method": "sinc_interp_kaiser",
        "beta": 8.555504641634386,
    }
    lib_time = benchmark_resample_librosa(*args, res_type="kaiser_fast")
    f_time = benchmark_resample_functional(*args, **kwargs)
    t_time = benchmark_resample_transforms(*args, **kwargs)
    times.append([lib_time, f_time, t_time])
    rows.append("kaiser_fast")

    df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
    return df


######################################################################
#
def plot(df):
    """Print the benchmark table and draw it as an annotated bar chart.

    Args:
        df (pandas.DataFrame): Timings in milliseconds, one column per
            implementation (as returned by ``benchmark``).
    """
    print(df.round(2))
    ax = df.plot(kind="bar")
    plt.ylabel("Time Elapsed [ms]")
    plt.xticks(rotation=0, fontsize=10)
    # Label each bar with its rounded value in its series' color; missing timings show "N/A".
    for cont, col, color in zip(ax.containers, df.columns, mcolors.TABLEAU_COLORS):
        label = ["N/A" if pd.isna(v) else str(v) for v in df[col].round(2)]
        ax.bar_label(cont, labels=label, color=color, fontweight="bold", fontsize="x-small")


######################################################################
#
# Downsample (48 -> 44.1 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(sample_rate=48_000, resample_rate=44_100)
plot(df=df)

######################################################################
#
# Downsample (16 -> 8 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(sample_rate=16_000, resample_rate=8_000)
plot(df=df)

######################################################################
#
# Upsample (44.1 -> 48 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(sample_rate=44_100, resample_rate=48_000)
plot(df=df)

######################################################################
#
# Upsample (8 -> 16 kHz)
# ~~~~~~~~~~~~~~~~~~~~~~

df = benchmark(8_000, 16_000)
plot(df)
######################################################################
#
# Summary
# ~~~~~~~
#
# To elaborate on the results:
#
# - a larger ``lowpass_filter_width`` results in a larger resampling kernel,
#   and therefore increases computation time for both the kernel computation
#   and the convolution
# - using ``sinc_interp_kaiser`` results in longer computation times than the default
#   ``sinc_interp_hann`` because it is more complex to compute the intermediate
#   window values
# - a large GCD between the sample and resample rate will result
#   in a simplification that allows for a smaller kernel and faster kernel computation.
#