mvdr_tutorial.py 13 KB
Newer Older
moto's avatar
moto committed
1
"""
2
3
Speech Enhancement with MVDR Beamforming
========================================
moto's avatar
moto committed
4

5
**Author**: `Zhaoheng Ni <zni@meta.com>`__
moto's avatar
moto committed
6
7
8

"""

9

moto's avatar
moto committed
10
######################################################################
11
12
# 1. Overview
# -----------
13
#
14
15
16
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
17
#
18
19
20
21
22
23
24
25
26
27
28
29
# Steps:
#
# -  Generate an ideal ratio mask (IRM) by dividing the clean/noise
#    magnitude by the mixture magnitude.
# -  Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# -  Estimate enhanced speech using MVDR modules
#    (:py:func:`torchaudio.transforms.SoudenMVDR` and
#    :py:func:`torchaudio.transforms.RTFMVDR`).
# -  Benchmark the two methods
#    (:py:func:`torchaudio.functional.rtf_evd` and
#    :py:func:`torchaudio.functional.rtf_power`) for computing the
#    relative transfer function (RTF) matrix of the reference microphone.
30
#
31
32
33
34
35
36
37

import torch
import torchaudio
import torchaudio.functional as F

# Print library versions so tutorial output is reproducible/debuggable.
print(torch.__version__)
print(torchaudio.__version__)
moto's avatar
moto committed
38
39


mayp777's avatar
UPDATE  
mayp777 committed
40
41
42
43
import matplotlib.pyplot as plt
import mir_eval
from IPython.display import Audio

moto's avatar
moto committed
44
######################################################################
45
46
# 2. Preparation
# --------------
47
#
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

######################################################################
# 2.1. Import the packages
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, we install and import the necessary packages.
#
# ``mir_eval``, ``pesq``, and ``pystoi`` packages are required for
# evaluating the speech enhancement performance.
#

# When running this example in notebook, install the following packages.
# !pip3 install mir_eval
# !pip3 install pesq
# !pip3 install pystoi

from pesq import pesq
from pystoi import stoi
from torchaudio.utils import download_asset

######################################################################
# 2.2. Download audio data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
71
#
moto's avatar
moto committed
72
73
74
# The multi-channel audio example is selected from
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
75
#
moto's avatar
moto committed
76
# The original filename is
77
#
moto's avatar
moto committed
78
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
79
#
80
# which was generated with:
81
#
82
83
84
85
86
87
# -  ``SSB07200001.wav`` from
#    `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
#    v.2.0)
# -  ``noise-sound-bible-0038.wav`` from
#    `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
#    International — CC BY 4.0)
88
#
moto's avatar
moto committed
89

90
91
92
# All audio in this tutorial is 16 kHz; the two assets are fetched (and
# cached locally) from torchaudio's tutorial asset store.
SAMPLE_RATE = 16000
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
moto's avatar
moto committed
93
94


95
######################################################################
96
# 2.3. Helper functions
97
98
99
# ~~~~~~~~~~~~~~~~~~~~~
#

moto's avatar
moto committed
100

mayp777's avatar
UPDATE  
mayp777 committed
101
def plot_spectrogram(stft, title="Spectrogram"):
    """Plot the magnitude of complex STFT coefficients on a dB scale."""
    # Convert to log-magnitude (dB); the small offset avoids log10(0)
    # for silent time-frequency bins.
    spec_db = 20 * torch.log10(stft.abs() + 1e-8).numpy()
    _, ax = plt.subplots(1, 1)
    image = ax.imshow(spec_db, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    ax.set_title(title)
    plt.colorbar(image, ax=ax)


mayp777's avatar
UPDATE  
mayp777 committed
110
def plot_mask(mask, title="Mask"):
    """Display a time-frequency mask (values in [0, 1]) as an image."""
    _, ax = plt.subplots(1, 1)
    image = ax.imshow(mask.numpy(), cmap="viridis", origin="lower", aspect="auto")
    ax.set_title(title)
    plt.colorbar(image, ax=ax)


def si_snr(estimate, reference, epsilon=1e-8):
    """Compute the scale-invariant signal-to-noise ratio (Si-SNR) in dB.

    Args:
        estimate: estimated waveform, shape (1, time).
        reference: reference (clean) waveform, shape (1, time).
        epsilon: small constant guarding against division by zero.

    Returns:
        Python float with the Si-SNR score.
    """
    # Remove DC offsets so the metric ignores constant shifts.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()

    # Project the estimate onto the reference to find the optimal
    # rescaling of the reference (this is what makes it scale-invariant).
    ref_energy = reference.pow(2).mean(axis=1, keepdim=True)
    projection = (estimate * reference).mean(axis=1, keepdim=True)
    target = (projection / (ref_energy + epsilon)) * reference

    residual = estimate - target

    target_power = target.pow(2).mean(axis=1)
    residual_power = residual.pow(2).mean(axis=1)

    score = 10 * torch.log10(target_power) - 10 * torch.log10(residual_power)
    return score.item()
moto's avatar
moto committed
136
137


138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech and noise at the requested SNR (in dB).

    NOTE: ``waveform_noise`` is rescaled *in place* so that the returned
    mixture has the target SNR; later cells in this tutorial rely on the
    scaled noise, so pass a copy if you need the original level.
    """
    clean_power = waveform_clean.pow(2).mean()
    noise_power = waveform_noise.pow(2).mean()
    snr_now = 10 * torch.log10(clean_power / noise_power)
    # Amplitude gain that moves the current SNR to the target SNR.
    gain = 10 ** (-(target_snr - snr_now) / 20)
    waveform_noise *= gain
    return waveform_clean + waveform_noise


def evaluate(estimate, reference):
    """Print SDR, Si-SNR, PESQ, and STOI scores for ``estimate`` vs ``reference``.

    Args:
        estimate: enhanced/mixture waveform, shape (1, time), at SAMPLE_RATE.
        reference: clean reference waveform, same shape.
    """
    si_snr_score = si_snr(estimate, reference)
    # compute_permutation=False: single source, no permutation search needed.
    (
        sdr,
        _,
        _,
        _,
    ) = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)
    # pesq() expects (fs, ref, deg, mode) — reference first, degraded second.
    # The original code passed them swapped; PESQ is not symmetric.
    pesq_mix = pesq(SAMPLE_RATE, reference[0].numpy(), estimate[0].numpy(), "wb")
    stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)
    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_mix}")
    print(f"STOI score: {stoi_mix}")


moto's avatar
moto committed
162
######################################################################
163
164
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
165
#
moto's avatar
moto committed
166

167

moto's avatar
moto committed
168
######################################################################
169
170
# 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~~~
171
#
moto's avatar
moto committed
172

173
174
175
# Load the two source recordings and confirm they share the tutorial's rate.
waveform_clean, rate_clean = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, rate_noise = torchaudio.load(SAMPLE_NOISE)
assert rate_clean == rate_noise == SAMPLE_RATE

# The mixture waveform is a combination of clean and noise waveforms with a desired SNR.
target_snr = 3
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)
moto's avatar
moto committed
179
180
181


######################################################################
182
183
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
184
#
moto's avatar
moto committed
185

186
187
188
189
# Cast everything to float64: double precision improves the numerical
# robustness of the PSD/MVDR computations downstream.
waveform_mix = waveform_mix.to(torch.double)
waveform_clean = waveform_clean.to(torch.double)
waveform_noise = waveform_noise.to(torch.double)

moto's avatar
moto committed
190
191

######################################################################
192
193
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
#
moto's avatar
moto committed
195

196
197
N_FFT = 1024
N_HOP = 256

# power=None keeps the raw complex-valued STFT coefficients, which the
# PSD and MVDR modules require.
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# STFTs of the mixture and of its two components.
stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)
moto's avatar
moto committed
208
209
210


######################################################################
211
212
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
213
#
214
215
216
217
218
219
# We evaluate the quality of the mixture speech or the enhanced speech
# using the following three metrics:
#
# -  signal-to-distortion ratio (SDR)
# -  scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)
# -  Perceptual Evaluation of Speech Quality (PESQ)
moto's avatar
moto committed
220
#
221
222
# We also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility
# (STOI) metric.
moto's avatar
moto committed
223

224
# Baseline: score the unprocessed mixture against the clean reference.
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
evaluate(waveform_mix[0:1], waveform_clean[0:1])
Audio(waveform_mix[0], rate=SAMPLE_RATE)
moto's avatar
moto committed
227

228

moto's avatar
moto committed
229
######################################################################
230
231
232
# 3.2.2. Visualize clean speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
moto's avatar
moto committed
233

234
235
# Clean speech: the reference for all enhancement metrics below.
plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)
moto's avatar
moto committed
236
237
238


######################################################################
239
240
# 3.2.3. Visualize noise
# ^^^^^^^^^^^^^^^^^^^^^^
241
#
moto's avatar
moto committed
242

243
244
245
# The noise component of the mixture (already SNR-scaled by generate_mixture).
plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)

moto's avatar
moto committed
246
247

######################################################################
248
249
250
251
252
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
253
#
254
255
256
257
# You can also apply an end-to-end neural network which estimates both the reference channel and
# the PSD matrices, then obtains the enhanced STFT coefficients by the MVDR module.

# Index of the microphone used as the beamforming reference (first mic here).
REFERENCE_CHANNEL = 0
moto's avatar
moto committed
258
259
260


######################################################################
261
262
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
263
264
#

moto's avatar
moto committed
265

266
267
268
269
270
271
def get_irms(stft_clean, stft_noise, reference_channel=None):
    """Compute oracle ideal ratio masks (IRMs) for speech and noise.

    Args:
        stft_clean: complex STFT of the clean speech, shape (channel, freq, time).
        stft_noise: complex STFT of the noise, same shape.
        reference_channel: channel whose masks are returned; defaults to the
            module-level ``REFERENCE_CHANNEL``.

    Returns:
        Tuple ``(irm_speech, irm_noise)`` for the reference channel, each of
        shape (freq, time) with values in [0, 1] that sum to 1 per bin.
    """
    if reference_channel is None:
        reference_channel = REFERENCE_CHANNEL
    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    total_power = power_clean + power_noise
    irm_speech = power_clean / total_power
    irm_noise = power_noise / total_power
    return irm_speech[reference_channel], irm_noise[reference_channel]
moto's avatar
moto committed
272
273


274
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)
moto's avatar
moto committed
275

276

moto's avatar
moto committed
277
######################################################################
278
279
# 3.4.1. Visualize IRM of target speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
280
#
moto's avatar
moto committed
281

282
plot_mask(irm_speech, "IRM of the Target Speech")
moto's avatar
moto committed
283
284
285


######################################################################
286
287
# 3.4.2. Visualize IRM of noise
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
288
#
moto's avatar
moto committed
289

290
plot_mask(irm_noise, "IRM of the Noise")
moto's avatar
moto committed
291
292

######################################################################
293
294
# 4. Compute PSD matrices
# -----------------------
295
#
296
297
298
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued STFT coefficients  of the mixture speech
# and the time-frequency mask.
299
#
300
# The shape of the PSD matrix is `(..., freq, channel, channel)`.
moto's avatar
moto committed
301

302
# Time-invariant PSD matrices for target speech and noise, estimated from
# the mixture STFT weighted by the corresponding IRMs.
# Resulting shape: (..., freq, channel, channel).
psd_transform = torchaudio.transforms.PSD()

psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)
moto's avatar
moto committed
306
307
308


######################################################################
309
310
# 5. Beamforming using SoudenMVDR
# -------------------------------
311
#
moto's avatar
moto committed
312
313
314


######################################################################
315
316
317
318
319
320
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complexed-valued STFT coefficients of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel inputs.
321
#
322
323
324
325
326
327
328
329
# The output is a single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# SoudenMVDR consumes the mixture STFT plus the speech/noise PSD matrices
# and emits single-channel enhanced STFT coefficients for the reference channel.
mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
# Invert back to a time-domain waveform, trimmed to the mixture's length.
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])

moto's avatar
moto committed
330
331

######################################################################
332
333
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
334
#
moto's avatar
moto committed
335

336
337
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
# Reshape to (1, time) so the metric helpers see a batched waveform.
waveform_souden = waveform_souden.reshape(1, -1)
evaluate(waveform_souden, waveform_clean[0:1])
Audio(waveform_souden, rate=SAMPLE_RATE)

moto's avatar
moto committed
341
342

######################################################################
343
344
# 6. Beamforming using RTFMVDR
# ----------------------------
345
#
moto's avatar
moto committed
346
347
348


######################################################################
349
350
351
352
353
354
355
356
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of a
# target speech:
#
# -  :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
#    decomposition to the PSD matrix of target speech to get the RTF matrix.
357
#
358
359
360
361
362
363
# -  :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
#    method. You can specify the number of iterations with argument ``n_iter``.
#

# Two estimates of the target's relative transfer function (RTF):
# eigenvalue decomposition of the speech PSD, and the power-iteration
# method (which additionally needs the noise PSD and reference channel).
rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
moto's avatar
moto committed
364
365
366


######################################################################
367
368
369
370
371
372
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complexed-valued STFT coefficients of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel inputs.
373
#
374
375
376
377
378
379
380
381
382
383
384
385
386
# The output is a single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# RTFMVDR consumes the mixture STFT, the target RTF vector, the noise PSD,
# and the reference channel; it emits single-channel enhanced STFT coefficients.
mvdr_transform = torchaudio.transforms.RTFMVDR()

# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])
moto's avatar
moto committed
387
388
389


######################################################################
390
391
# 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
392
#
moto's avatar
moto committed
393

394
395
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
# Reshape to (1, time) so the metric helpers see a batched waveform.
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
evaluate(waveform_rtf_evd, waveform_clean[0:1])
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)

moto's avatar
moto committed
399
400

######################################################################
401
402
# 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
403
#
moto's avatar
moto committed
404

405
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
# Reshape to (1, time) so the metric helpers see a batched waveform.
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
evaluate(waveform_rtf_power, waveform_clean[0:1])
Audio(waveform_rtf_power, rate=SAMPLE_RATE)