"""
Speech Enhancement with MVDR Beamforming
========================================

**Author** `Zhaoheng Ni <zni@fb.com>`__

"""



######################################################################
# 1. Overview
# -----------
#
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
#
# Steps:
#
# -  Generate an ideal ratio mask (IRM) by dividing the clean/noise
#    magnitude by the mixture magnitude.
# -  Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# -  Estimate enhanced speech using MVDR modules
#    (:py:func:`torchaudio.transforms.SoudenMVDR` and
#    :py:func:`torchaudio.transforms.RTFMVDR`).
# -  Benchmark the two methods
#    (:py:func:`torchaudio.functional.rtf_evd` and
#    :py:func:`torchaudio.functional.rtf_power`) for computing the
#    relative transfer function (RTF) matrix of the reference microphone.
#

import torch
import torchaudio
import torchaudio.functional as F

# Print the library versions so readers can reproduce the tutorial results.
print(torch.__version__)
print(torchaudio.__version__)


######################################################################
# 2. Preparation
# --------------
#

######################################################################
# 2.1. Import the packages
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, we install and import the necessary packages.
#
# ``mir_eval``, ``pesq``, and ``pystoi`` packages are required for
# evaluating the speech enhancement performance.
#

# When running this example in notebook, install the following packages.
# !pip3 install mir_eval
# !pip3 install pesq
# !pip3 install pystoi

from pesq import pesq
from pystoi import stoi
import mir_eval

import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset

######################################################################
# 2.2. Download audio data
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The multi-channel audio example is selected from
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
# The original filename is
#
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# -  ``SSB07200001.wav`` from
#    `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
#    v.2.0)
# -  ``noise-sound-bible-0038.wav`` from
#    `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
#    International — CC BY 4.0)
#

# Sampling rate (Hz) of both tutorial assets.
SAMPLE_RATE = 16000
# download_asset fetches (and caches) the example audio files.
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")


######################################################################
# 2.3. Helper functions
# ~~~~~~~~~~~~~~~~~~~~~
#

def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Plot the magnitude of complex STFT coefficients on a dB scale.

    ``xlim`` is accepted for API compatibility but is not used.
    """
    spec_db = 20 * torch.log10(stft.abs() + 1e-8).numpy()  # epsilon avoids log10(0)
    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(spec_db, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    """Plot a 2-D time-frequency mask.

    ``xlim`` is accepted for API compatibility but is not used.
    """
    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(mask.numpy(), cmap="viridis", origin="lower", aspect="auto")
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()


def si_snr(estimate, reference, epsilon=1e-8):
    """Compute the scale-invariant signal-to-noise ratio (Si-SNR) in dB.

    Both inputs are tensors of shape `(batch, time)`; returns a Python float.
    """
    # Remove the DC offset so the metric is invariant to constant shifts.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()

    # Project the estimate onto the reference to find the optimal scale.
    reference_energy = reference.pow(2).mean(axis=1, keepdim=True)
    cross_energy = (estimate * reference).mean(axis=1, keepdim=True)
    target = reference * (cross_energy / (reference_energy + epsilon))

    # Si-SNR is the energy ratio between the scaled target and the residual.
    residual = estimate - target
    target_power = target.pow(2).mean(axis=1)
    residual_power = residual.pow(2).mean(axis=1)

    ratio_db = 10 * torch.log10(target_power) - 10 * torch.log10(residual_power)
    return ratio_db.item()


def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech and noise so the result has ``target_snr`` dB SNR.

    Note: ``waveform_noise`` is rescaled **in place**; later cells in this
    tutorial rely on the scaled noise when computing the ideal ratio masks.
    """
    snr_now = 10 * torch.log10(waveform_clean.pow(2).mean() / waveform_noise.pow(2).mean())
    # Scale the noise to move from the current SNR to the target SNR.
    waveform_noise *= 10 ** ((snr_now - target_snr) / 20)
    return waveform_clean + waveform_noise


def evaluate(estimate, reference):
    """Print SDR, Si-SNR, PESQ, and STOI scores of ``estimate`` against ``reference``."""
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
    scores = {
        "SDR score": sdr[0],
        "Si-SNR score": si_snr(estimate, reference),
        "PESQ score": pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), "wb"),
        "STOI score": stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False),
    }
    for name, value in scores.items():
        print(f"{name}: {value}")


######################################################################
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
#


######################################################################
# 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~~~
#

# Load the multi-channel clean speech and noise, and verify the sampling rates.
waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE

# The mixture waveform is a combination of clean and noise waveforms with a desired SNR.
target_snr = 3
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)


######################################################################
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
#

# Cast to double precision for numerically robust PSD estimation and MVDR solves.
waveform_mix = waveform_mix.to(torch.double)
waveform_clean = waveform_clean.to(torch.double)
waveform_noise = waveform_noise.to(torch.double)


######################################################################
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

N_FFT = 1024
N_HOP = 256

# power=None keeps the complex-valued STFT coefficients, which the
# PSD/MVDR modules require.
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)


######################################################################
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We evaluate the quality of the mixture speech or the enhanced speech
# using the following three metrics:
#
# -  signal-to-distortion ratio (SDR)
# -  scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)
# -  Perceptual Evaluation of Speech Quality (PESQ)
#
# We also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility
# (STOI) metric.

# Baseline: score the unprocessed mixture against the clean reference.
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
evaluate(waveform_mix[0:1], waveform_clean[0:1])
Audio(waveform_mix[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.2. Visualize clean speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

# Reference channel of the clean speech, for comparison with the mixture.
plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.3. Visualize noise
# ^^^^^^^^^^^^^^^^^^^^^^
#

# Reference channel of the noise component.
plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)

######################################################################
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
#
# You can also apply an end-to-end neural network which estimates both the reference channel and
# the PSD matrices, then obtains the enhanced STFT coefficients by the MVDR module.

REFERENCE_CHANNEL = 0  # use the first microphone as the reference channel


######################################################################
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
#

def get_irms(stft_clean, stft_noise):
    """Compute ideal ratio masks for speech and noise at the reference channel.

    Masks are power ratios derived from the (oracle) clean and noise STFTs.
    """
    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    total_power = power_clean + power_noise
    irm_speech = power_clean / total_power
    irm_noise = power_noise / total_power
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]


# Oracle masks computed from the known clean and noise STFTs.
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)


######################################################################
# 3.4.1. Visualize IRM of target speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_mask(irm_speech, "IRM of the Target Speech")


######################################################################
# 3.4.2. Visualize IRM of noise
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_mask(irm_noise, "IRM of the Noise")


######################################################################
# 4. Compute PSD matrices
# -----------------------
#
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued STFT coefficients of the mixture speech
# and the time-frequency mask.
#
# The shape of the PSD matrix is `(..., freq, channel, channel)`.

psd_transform = torchaudio.transforms.PSD()

# Mask-weighted PSD matrices of the target speech and of the noise.
psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)


######################################################################
# 5. Beamforming using SoudenMVDR
# -------------------------------
#


######################################################################
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# SoudenMVDR consumes the mixture STFT plus the speech/noise PSD matrices and
# returns the enhanced single-channel STFT, which we invert back to a waveform.
mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])



######################################################################
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Score the SoudenMVDR output against the clean reference channel.
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1)
evaluate(waveform_souden, waveform_clean[0:1])
Audio(waveform_souden, rate=SAMPLE_RATE)



######################################################################
# 6. Beamforming using RTFMVDR
# ----------------------------
#


######################################################################
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of a
# target speech:
#
# -  :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
#    decomposition to the PSD matrix of target speech to get the RTF matrix.
#
# -  :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
#    method. You can specify the number of iterations with argument ``n_iter``.
#

# RTF via eigenvalue decomposition, and via the power-iteration method.
rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)


######################################################################
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# RTFMVDR consumes the mixture STFT, the target-speech RTF, and the noise PSD,
# returning the enhanced single-channel STFT for each RTF estimation method.
mvdr_transform = torchaudio.transforms.RTFMVDR()

# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])


######################################################################
# 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Score the RTFMVDR (rtf_evd) output against the clean reference channel.
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
evaluate(waveform_rtf_evd, waveform_clean[0:1])
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)


######################################################################
# 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Score the RTFMVDR (rtf_power) output against the clean reference channel.
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
evaluate(waveform_rtf_power, waveform_clean[0:1])
Audio(waveform_rtf_power, rate=SAMPLE_RATE)