"""
Speech Enhancement with MVDR Beamforming
========================================

**Author** `Zhaoheng Ni <zni@fb.com>`__

"""

######################################################################
# 1. Overview
# -----------
#
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
#
# Steps:
#
# -  Generate an ideal ratio mask (IRM) by dividing the clean/noise
#    magnitude by the mixture magnitude.
# -  Estimate power spectral density (PSD) matrices using :py:class:`torchaudio.transforms.PSD`.
# -  Estimate enhanced speech using MVDR modules
#    (:py:class:`torchaudio.transforms.SoudenMVDR` and
#    :py:class:`torchaudio.transforms.RTFMVDR`).
# -  Benchmark the two methods
#    (:py:func:`torchaudio.functional.rtf_evd` and
#    :py:func:`torchaudio.functional.rtf_power`) for computing the
#    relative transfer function (RTF) matrix of the reference microphone.
#
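# For reference, all of the beamformers below solve, per frequency bin, the
# same constrained optimization problem: minimize the output noise power
# while keeping the target direction distortionless,
#
# .. math::
#    \mathbf{w}_{\text{MVDR}}
#    = \arg\min_{\mathbf{w}} \mathbf{w}^{\mathsf{H}} \boldsymbol{\Phi}_{\mathbf{NN}} \mathbf{w}
#    \quad \text{s.t.} \quad \mathbf{w}^{\mathsf{H}} \mathbf{d} = 1,
#
# where :math:`\boldsymbol{\Phi}_{\mathbf{NN}}` is the noise PSD matrix and
# :math:`\mathbf{d}` is the steering (RTF) vector of the target speech. The
# two MVDR modules used below differ only in how this solution is
# parameterized.
#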

import torch
import torchaudio
import torchaudio.functional as F

print(torch.__version__)
print(torchaudio.__version__)


######################################################################
# 2. Preparation
# --------------
#

######################################################################
# 2.1. Import the packages
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, we install and import the necessary packages.
#
# The ``mir_eval``, ``pesq``, and ``pystoi`` packages are required for
# evaluating the speech enhancement performance.
#

# When running this example in a notebook, install the following packages.
# !pip3 install mir_eval
# !pip3 install pesq
# !pip3 install pystoi

from pesq import pesq
from pystoi import stoi
import mir_eval

import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset

######################################################################
# 2.2. Download audio data
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# The multi-channel audio example is selected from the
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
# The original filename is
#
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# -  ``SSB07200001.wav`` from
#    `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
#    v.2.0)
# -  ``noise-sound-bible-0038.wav`` from
#    `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
#    International — CC BY 4.0)
#

SAMPLE_RATE = 16000
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")


######################################################################
# 2.3. Helper functions
# ~~~~~~~~~~~~~~~~~~~~~
#
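# Below we define functions for plotting spectrograms and masks, generating
# noisy mixtures at a desired SNR, and evaluating enhancement quality. In
# particular, the ``si_snr`` helper implements the scale-invariant
# signal-to-noise ratio
#
# .. math::
#    \text{Si-SNR} = 10 \log_{10}
#    \frac{\| \alpha s \|^2}{\| \hat{s} - \alpha s \|^2},
#    \qquad
#    \alpha = \frac{\langle \hat{s}, s \rangle}{\| s \|^2},
#
# where :math:`\hat{s}` and :math:`s` are the mean-removed estimate and
# reference waveforms, respectively.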


def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def si_snr(estimate, reference, epsilon=1e-8):
    # Remove the mean (DC offset) so the measure is invariant to it
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    # Optimal scale that projects the estimate onto the reference
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    # Ratio of scaled-reference power to error power, in dB
    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return si_snr.item()


def generate_mixture(waveform_clean, waveform_noise, target_snr):
    power_clean_signal = waveform_clean.pow(2).mean()
    power_noise_signal = waveform_noise.pow(2).mean()
    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
    # Scale the noise (in place) so that the clean-to-noise ratio becomes ``target_snr`` dB
    waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
    return waveform_clean + waveform_noise


def evaluate(estimate, reference):
    si_snr_score = si_snr(estimate, reference)
    # ``bss_eval_sources`` returns (sdr, sir, sar, perm); we only need the SDR
    sdr, _, _, _ = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
    pesq_mix = pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), "wb")
    stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)
    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_mix}")
    print(f"STOI score: {stoi_mix}")


######################################################################
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
#


######################################################################
# 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~~~
#

waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms with a desired SNR.
target_snr = 3
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)
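
# Optional sanity check: ``generate_mixture`` scales ``waveform_noise`` in place,
# so the clean-to-noise power ratio should now match ``target_snr``.
snr = 10 * torch.log10(waveform_clean.pow(2).mean() / waveform_noise.pow(2).mean())
print(f"SNR of the mixture: {snr.item():.2f} dB")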


######################################################################
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values;
# the complex-valued STFT coefficients computed from them will then be ``torch.complex128``.
#

waveform_mix = waveform_mix.to(torch.double)
waveform_clean = waveform_clean.to(torch.double)
waveform_noise = waveform_noise.to(torch.double)


######################################################################
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

N_FFT = 1024
N_HOP = 256
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)
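
# The STFT coefficients are complex-valued tensors of shape
# ``(channel, freq, time)``; a quick optional check (the exact sizes depend
# on the downloaded clips):
print(stft_mix.shape)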


######################################################################
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We evaluate the quality of the mixture speech or the enhanced speech
# using the following three metrics:
#
# -  signal-to-distortion ratio (SDR)
# -  scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)
# -  Perceptual Evaluation of Speech Quality (PESQ)
#
# We also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility
# (STOI) metric.

plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
evaluate(waveform_mix[0:1], waveform_clean[0:1])
Audio(waveform_mix[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.2. Visualize clean speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.3. Visualize noise
# ^^^^^^^^^^^^^^^^^^^^^^
#

plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)


######################################################################
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
#
# You can also apply an end-to-end neural network that estimates both the reference channel and
# the PSD matrices, and then obtain the enhanced STFT coefficients with the MVDR module.

REFERENCE_CHANNEL = 0


######################################################################
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
#
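# The masks below are computed from the power spectrograms (``abs() ** 2``)
# of the clean speech and the noise. Assuming the speech and noise signals
# are uncorrelated, the IRM of the target speech is
#
# .. math::
#    \text{IRM}_{\text{speech}}(f, t)
#    = \frac{|S(f, t)|^2}{|S(f, t)|^2 + |N(f, t)|^2},
#
# and the noise mask is its complement,
# :math:`1 - \text{IRM}_{\text{speech}}(f, t)`.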


def get_irms(stft_clean, stft_noise):
    # Power spectrograms (squared magnitudes) of the clean speech and the noise
    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    irm_speech = power_clean / (power_clean + power_noise)
    irm_noise = power_noise / (power_clean + power_noise)
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]


irm_speech, irm_noise = get_irms(stft_clean, stft_noise)


######################################################################
# 3.4.1. Visualize IRM of target speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_mask(irm_speech, "IRM of the Target Speech")


######################################################################
# 3.4.2. Visualize IRM of noise
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_mask(irm_noise, "IRM of the Noise")

######################################################################
# 4. Compute PSD matrices
# -----------------------
#
# :py:class:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued STFT coefficients of the mixture speech
# and the time-frequency mask.
#
# The shape of the PSD matrix is ``(..., freq, channel, channel)``.

psd_transform = torchaudio.transforms.PSD()

psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)
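
# As noted above, each PSD matrix has shape ``(freq, channel, channel)``;
# a quick optional check:
print(psd_speech.shape)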


######################################################################
# 5. Beamforming using SoudenMVDR
# -------------------------------
#
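# In the formulation of Souden et al. (2010), which gives this module its
# name, the MVDR filter is written directly in terms of the speech and noise
# PSD matrices, with no explicit steering vector:
#
# .. math::
#    \mathbf{w}_{\text{Souden}}
#    = \frac{\boldsymbol{\Phi}_{\mathbf{NN}}^{-1} \boldsymbol{\Phi}_{\mathbf{SS}}}
#           {\operatorname{trace}\left(\boldsymbol{\Phi}_{\mathbf{NN}}^{-1} \boldsymbol{\Phi}_{\mathbf{SS}}\right)}
#    \mathbf{u},
#
# where :math:`\boldsymbol{\Phi}_{\mathbf{SS}}` and :math:`\boldsymbol{\Phi}_{\mathbf{NN}}`
# are the PSD matrices of the target speech and the noise, and
# :math:`\mathbf{u}` is a one-hot vector selecting the reference channel.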


######################################################################
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:class:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, the PSD matrices of
# the target speech and noise, and the reference channel as inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:class:`torchaudio.transforms.InverseSpectrogram` module.

mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])


######################################################################
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1)
evaluate(waveform_souden, waveform_clean[0:1])
Audio(waveform_souden, rate=SAMPLE_RATE)


######################################################################
# 6. Beamforming using RTFMVDR
# ----------------------------
#
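# Here the MVDR filter is parameterized by the RTF vector
# :math:`\mathbf{d}` of the target speech through the classic closed-form
# solution
#
# .. math::
#    \mathbf{w}_{\text{RTF}}
#    = \frac{\boldsymbol{\Phi}_{\mathbf{NN}}^{-1} \mathbf{d}}
#           {\mathbf{d}^{\mathsf{H}} \boldsymbol{\Phi}_{\mathbf{NN}}^{-1} \mathbf{d}},
#
# so we first need an estimate of :math:`\mathbf{d}`.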


######################################################################
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of the
# target speech:
#
# -  :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
#    decomposition to the PSD matrix of the target speech to get the RTF matrix.
#
# -  :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
#    method. You can specify the number of iterations with the argument ``n_iter``.
#

# RTF estimated via eigenvalue decomposition of the speech PSD matrix
rtf_evd = F.rtf_evd(psd_speech)
# RTF estimated via the power iteration method
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
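
# The number of power iterations can also be set explicitly via ``n_iter``;
# for illustration (the value 3 here is an arbitrary choice):
rtf_power_3 = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL, n_iter=3)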


######################################################################
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:class:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, the RTF matrix of the target speech,
# the PSD matrix of the noise, and the reference channel as inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:class:`torchaudio.transforms.InverseSpectrogram` module.

mvdr_transform = torchaudio.transforms.RTFMVDR()

# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])


######################################################################
# 6.3. Result for RTFMVDR with ``rtf_evd``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
evaluate(waveform_rtf_evd, waveform_clean[0:1])
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)


######################################################################
# 6.4. Result for RTFMVDR with ``rtf_power``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
evaluate(waveform_rtf_power, waveform_clean[0:1])
Audio(waveform_rtf_power, rate=SAMPLE_RATE)