"""
Speech Enhancement with MVDR Beamforming
========================================

**Author** `Zhaoheng Ni <zni@fb.com>`__

"""

######################################################################
# 1. Overview
# -----------
#
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
#
# Steps:
#
# -  Generate an ideal ratio mask (IRM) by dividing the clean/noise
#    magnitude by the mixture magnitude.
# -  Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# -  Estimate enhanced speech using MVDR modules
#    (:py:func:`torchaudio.transforms.SoudenMVDR` and
#    :py:func:`torchaudio.transforms.RTFMVDR`).
# -  Benchmark the two methods
#    (:py:func:`torchaudio.functional.rtf_evd` and
#    :py:func:`torchaudio.functional.rtf_power`) for computing the
#    relative transfer function (RTF) matrix of the reference microphone.
#

# Core dependencies; print versions so tutorial output is reproducible.
import torch
import torchaudio
import torchaudio.functional as F

print(torch.__version__)
print(torchaudio.__version__)


######################################################################
# 2. Preparation
# --------------
#
# First, we import the necessary packages and retrieve the data.
#
# The multi-channel audio example is selected from
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
# The original filename is
#
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# -  ``SSB07200001.wav`` from
#    `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
#    v.2.0)
# -  ``noise-sound-bible-0038.wav`` from
#    `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
#    International — CC BY 4.0)
#

import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset

# Download the clean speech and noise assets used throughout this tutorial.
SAMPLE_RATE = 16000
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")


######################################################################
# 2.1. Helper functions
# ~~~~~~~~~~~~~~~~~~~~~
#

def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Display complex STFT coefficients as a dB-scaled magnitude image.

    ``xlim`` is accepted for API compatibility but is unused.
    """
    db_scale = 20 * torch.log10(stft.abs() + 1e-8)
    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(db_scale.numpy(), cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    """Render a time-frequency mask as an image.

    ``xlim`` is accepted for API compatibility but is unused.
    """
    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(mask.numpy(), cmap="viridis", origin="lower", aspect="auto")
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()


def si_snr(estimate, reference, epsilon=1e-8):
    """Compute the scale-invariant signal-to-noise ratio (Si-SNR) in dB.

    Both tensors are expected to be 2-D ``(batch, time)`` with a single
    overall score extractable via ``.item()`` (i.e. batch of one).
    Returns a Python float.
    """
    # Remove DC offsets so the projection below is offset-free.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()

    # Optimal scale that projects the estimate onto the reference signal.
    ref_energy = reference.pow(2).mean(axis=1, keepdim=True)
    inner = (estimate * reference).mean(axis=1, keepdim=True)
    target = (inner / (ref_energy + epsilon)) * reference

    distortion = estimate - target

    ratio_db = 10 * torch.log10(target.pow(2).mean(axis=1)) - 10 * torch.log10(distortion.pow(2).mean(axis=1))
    return ratio_db.item()


######################################################################
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
#


######################################################################
# 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~~~
#

# Load the clean speech and noise recordings; both must share the sample rate.
waveform_clean, sr_clean = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr_noise = torchaudio.load(SAMPLE_NOISE)
assert sr_clean == sr_noise == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms
waveform_mix = waveform_clean + waveform_noise


######################################################################
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
#

# Double precision improves the numerical robustness of the PSD/MVDR math.
waveform_mix, waveform_clean, waveform_noise = (
    wav.to(torch.double) for wav in (waveform_mix, waveform_clean, waveform_noise)
)


######################################################################
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

N_FFT = 1024
N_HOP = 256
# power=None keeps the complex-valued STFT coefficients (needed by MVDR).
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)


######################################################################
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

# Show the reference-channel mixture spectrogram and render the audio widget.
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
Audio(waveform_mix[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.2. Visualize clean speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

# Show the reference-channel clean-speech spectrogram and render the audio widget.
plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)


######################################################################
# 3.2.3. Visualize noise
# ^^^^^^^^^^^^^^^^^^^^^^
#

# Show the reference-channel noise spectrogram and render the audio widget.
plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)


######################################################################
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
#
# You can also apply an end-to-end neural network which estimates both the reference channel and
# the PSD matrices, then obtains the enhanced STFT coefficients by the MVDR module.

# Index of the reference channel within the multi-channel recordings.
REFERENCE_CHANNEL = 0


######################################################################
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
#

def get_irms(stft_clean, stft_noise):
    """Compute ideal ratio masks for speech and noise on the reference channel.

    Masks are the ratio of each source's power spectrum to the summed power
    of both sources; returns ``(irm_speech, irm_noise)``.
    """
    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    total_power = power_clean + power_noise
    return (power_clean / total_power)[REFERENCE_CHANNEL], (power_noise / total_power)[REFERENCE_CHANNEL]


# Derive speech and noise masks from the oracle clean/noise STFTs.
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)


######################################################################
# 3.4.1. Visualize IRM of target speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

# Display the speech mask (values near 1 mark speech-dominated bins).
plot_mask(irm_speech, "IRM of the Target Speech")


######################################################################
# 3.4.2. Visualize IRM of noise
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

# Display the complementary noise mask.
plot_mask(irm_noise, "IRM of the Noise")


######################################################################
# 4. Compute PSD matrices
# -----------------------
#
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued STFT coefficients of the mixture speech
# and the time-frequency mask.
#
# The shape of the PSD matrix is `(..., freq, channel, channel)`.

psd_transform = torchaudio.transforms.PSD()

# Mask-weighted PSD estimates for the speech and noise components.
psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)


######################################################################
# 5. Beamforming using SoudenMVDR
# -------------------------------
#


######################################################################
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel inputs.
#
# The output is a single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# SoudenMVDR consumes the mixture STFT plus both PSD matrices directly.
mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])



######################################################################
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
# si_snr expects 2-D (batch, time) inputs, so restore the leading dimension.
waveform_souden = waveform_souden.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_souden, waveform_clean[0:1])}")
Audio(waveform_souden, rate=SAMPLE_RATE)



######################################################################
# 6. Beamforming using RTFMVDR
# ----------------------------
#


######################################################################
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of a
# target speech:
#
# -  :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
#    decomposition to the PSD matrix of target speech to get the RTF matrix.
#
# -  :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
#    method. You can specify the number of iterations with argument ``n_iter``.
#

# Two alternative RTF estimates derived from the same speech PSD matrix.
rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)


######################################################################
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel inputs.
#
# The output is a single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

# RTFMVDR consumes the mixture STFT, an RTF vector, and the noise PSD matrix.
mvdr_transform = torchaudio.transforms.RTFMVDR()

# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])


######################################################################
# 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
# si_snr expects 2-D (batch, time) inputs, so restore the leading dimension.
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_evd, waveform_clean[0:1])}")
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)


######################################################################
# 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Fix: the plot title previously said "F.rtf_evd" although this section
# shows the result computed with F.rtf_power.
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
# si_snr expects 2-D (batch, time) inputs, so restore the leading dimension.
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_power, waveform_clean[0:1])}")
Audio(waveform_rtf_power, rate=SAMPLE_RATE)