mvdr_tutorial.py 8.4 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
"""
MVDR with torchaudio
====================

**Author** `Zhaoheng Ni <zni@fb.com>`__

"""

######################################################################
# Overview
# --------
#
# This is a tutorial on how to apply MVDR beamforming with
# :py:func:`torchaudio.transforms.MVDR`.
#
# Steps
#
# - Ideal Ratio Mask (IRM) is generated by dividing the clean/noise
#   magnitude by the mixture magnitude.
# - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``)
#   of torchaudio's MVDR module.
# - We test the single-channel and multi-channel masks for MVDR beamforming.
#   The multi-channel mask is averaged along channel dimension when computing
#   the covariance matrices of speech and noise, respectively.


######################################################################
# Preparation
# -----------
#
# First, we import the necessary packages and retrieve the data.
#
# The multi-channel audio example is selected from the
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
# The original filename is
#
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0)
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0)  # noqa: E501
#
moto's avatar
moto committed
46
47

import os
48
49

import IPython.display as ipd
moto's avatar
moto committed
50
51
52
53
54
import requests
import torch
import torchaudio

# Seed PyTorch's random number generator with a fixed value.
torch.random.manual_seed(0)
# Select a CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report the library versions and the selected device.
print(torch.__version__)
print(torchaudio.__version__)
print(device)

filenames = [
    "mix.wav",
    "reverb_clean.wav",
    "clean.wav",
]
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr"

# Create the local asset directory once, before any file is written into it.
os.makedirs("_assets", exist_ok=True)

for filename in filenames:
    local_path = f"_assets/{filename}"
    # Test the path that is actually written (inside ``_assets``) so that a
    # file downloaded by a previous run is reused instead of fetched again.
    if not os.path.exists(local_path):
        # Fetch before opening the destination so that a failed request does
        # not leave an empty or bogus ``.wav`` file behind.
        response = requests.get(f"{base_url}/{filename}")
        response.raise_for_status()
        with open(local_path, "wb") as file:
            file.write(response.content)
moto's avatar
moto committed
73
74
75

######################################################################
# Generate the Ideal Ratio Mask (IRM)
moto's avatar
moto committed
76
# -----------------------------------
77
#
moto's avatar
moto committed
78
79
80

######################################################################
# Loading audio data
moto's avatar
moto committed
81
# ~~~~~~~~~~~~~~~~~~
82
#
moto's avatar
moto committed
83

84
85
86
# Load the noisy mixture, the reverberant clean speech and the dry clean speech.
mix, sr = torchaudio.load("_assets/mix.wav")
reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav")
clean, sr3 = torchaudio.load("_assets/clean.wav")
# All three recordings are processed together, so they must share one sample rate.
assert sr == sr2 == sr3

# The noise signal is what remains of the mixture once the reverberant speech
# is subtracted.
noise = mix - reverb_clean

######################################################################
#
# .. note::
#    The MVDR module requires ``torch.cdouble`` dtype for noisy STFT.
#    We need to convert the dtype of the waveforms to ``torch.double``.
#
moto's avatar
moto committed
97
98
99
100
101
102
103
104

# Cast every waveform to double precision in one pass.
mix, noise, clean, reverb_clean = (
    waveform.to(torch.double) for waveform in (mix, noise, clean, reverb_clean)
)

######################################################################
# Compute STFT
moto's avatar
moto committed
105
# ~~~~~~~~~~~~
106
#
moto's avatar
moto committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

# Forward and inverse transforms share the same FFT size and hop length.
# ``power=None`` is passed so the spectrogram is not reduced to a power value.
stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, power=None)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)

# Transform each waveform with the same STFT configuration.
spec_mix, spec_clean, spec_reverb_clean, spec_noise = (
    stft(waveform) for waveform in (mix, clean, reverb_clean, noise)
)

######################################################################
# Generate the Ideal Ratio Mask (IRM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. note::
#    We found using the mask directly performs better than using the
#    square root of it. This is slightly different from the definition of IRM.
#
moto's avatar
moto committed
128
129


130
def get_irms(spec_clean, spec_noise):
    """Compute ratio masks for speech and noise from their spectrograms.

    Each mask is the power (squared magnitude) of one source divided by the
    summed power of both sources, evaluated element-wise.

    Args:
        spec_clean: Spectrogram tensor of the target speech.
        spec_noise: Spectrogram tensor of the noise, broadcastable against
            ``spec_clean``.

    Returns:
        Tuple ``(irm_speech, irm_noise)`` of real-valued masks. Wherever the
        summed power is non-zero the two masks add up to one; where both
        sources have zero power both masks are zero.
    """
    power_clean = spec_clean.abs() ** 2
    power_noise = spec_noise.abs() ** 2
    total_power = power_clean + power_noise
    # Elements where both sources are silent would evaluate 0 / 0 and yield
    # NaN masks. Dividing by one there gives a mask of zero and leaves every
    # other element unchanged.
    safe_total = torch.where(total_power == 0, torch.ones_like(total_power), total_power)
    irm_speech = power_clean / safe_total
    irm_noise = power_noise / safe_total

    return irm_speech, irm_noise

138

moto's avatar
moto committed
139
140
141
142
143
######################################################################
# .. note::
#    We use reverberant clean speech as the target here,
#    you can also set it to dry clean speech.

144
# The reverberant clean speech is passed as the target spectrogram.
irm_speech, irm_noise = get_irms(
    spec_clean=spec_reverb_clean,
    spec_noise=spec_noise,
)
moto's avatar
moto committed
145
146
147

######################################################################
# Apply MVDR
moto's avatar
moto committed
148
# ----------
149
#
moto's avatar
moto committed
150
151
152

######################################################################
# Apply MVDR beamforming by using multi-channel masks
moto's avatar
moto committed
153
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154
#
moto's avatar
moto committed
155
156

# Enhanced waveform for each MVDR solution, keyed by solution name.
results_multi = {}
for solution in ["ref_channel", "stv_evd", "stv_power"]:
    # ``multi_mask=True``: the masks are passed with all of their channels.
    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)
    stft_est = mvdr(spec_mix, irm_speech, irm_noise)
    # Convert back to a waveform with the same number of samples as the mixture.
    est = istft(stft_est, length=mix.shape[-1])
    results_multi[solution] = est

######################################################################
# Apply MVDR beamforming by using single-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We use the 1st channel as an example.
# The channel selection may depend on the design of the microphone array.

# Enhanced waveform for each MVDR solution, keyed by solution name.
results_single = {}
for solution in ["ref_channel", "stv_evd", "stv_power"]:
    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
    # ``multi_mask=False``: only the masks of the first channel are passed.
    stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
    # Convert back to a waveform with the same number of samples as the mixture.
    est = istft(stft_est, length=mix.shape[-1])
    results_single[solution] = est

######################################################################
# Compute Si-SDR scores
moto's avatar
moto committed
179
# ~~~~~~~~~~~~~~~~~~~~~
180
181
#

moto's avatar
moto committed
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201

def si_sdr(estimate, reference, epsilon=1e-8):
    """Return the scale-invariant signal-to-distortion ratio in decibels.

    Both signals have their overall mean removed. The reference is then
    rescaled by its projection coefficient onto the estimate, and the score
    is the power of that scaled reference over the power of the residual.

    Args:
        estimate: Estimated signal tensor; reductions run over axis 1, so it
            needs at least two dimensions.
        reference: Reference signal tensor with a matching shape.
        epsilon: Small constant added to the reference power before dividing.

    Returns:
        The score as a Python float. ``.item()`` is used, so the reduction
        must produce exactly one value.
    """
    centered_est = estimate - estimate.mean()
    centered_ref = reference - reference.mean()

    # Projection coefficient of the estimate onto the reference.
    ref_energy = centered_ref.pow(2).mean(axis=1, keepdim=True)
    cross_energy = (centered_est * centered_ref).mean(axis=1, keepdim=True)
    target = (cross_energy / (ref_energy + epsilon)) * centered_ref
    residual = centered_est - target

    target_db = 10 * torch.log10(target.pow(2).mean(axis=1))
    residual_db = 10 * torch.log10(residual.pow(2).mean(axis=1))
    return (target_db - residual_db).item()

202

moto's avatar
moto committed
203
204
######################################################################
# Results
moto's avatar
moto committed
205
# -------
206
#
moto's avatar
moto committed
207
208
209

######################################################################
# Single-channel mask results
moto's avatar
moto committed
210
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
211
#
moto's avatar
moto committed
212
213

# Score every single-channel-mask estimate against the first channel of the
# reverberant clean speech; a leading batch axis is added to the estimate.
for solution, estimate in results_single.items():
    print(solution + ": ", si_sdr(estimate[None, ...], reverb_clean[0:1]))
moto's avatar
moto committed
215
216
217

######################################################################
# Multi-channel mask results
moto's avatar
moto committed
218
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
219
#
moto's avatar
moto committed
220
221

# Score every multi-channel-mask estimate against the first channel of the
# reverberant clean speech; a leading batch axis is added to the estimate.
for solution, estimate in results_multi.items():
    print(solution + ": ", si_sdr(estimate[None, ...], reverb_clean[0:1]))
moto's avatar
moto committed
223
224
225

######################################################################
# Original audio
moto's avatar
moto committed
226
# --------------
227
#
moto's avatar
moto committed
228
229
230

######################################################################
# Mixture speech
moto's avatar
moto committed
231
# ~~~~~~~~~~~~~~
232
#
moto's avatar
moto committed
233
234
235
236
237

# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(mix[0], rate=sr)

######################################################################
# Noise
moto's avatar
moto committed
238
# ~~~~~
239
#
moto's avatar
moto committed
240
241
242
243
244

# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(noise[0], rate=sr)

######################################################################
# Clean speech
moto's avatar
moto committed
245
# ~~~~~~~~~~~~
246
#
moto's avatar
moto committed
247
248
249
250
251

# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(clean[0], rate=sr)

######################################################################
# Enhanced audio
moto's avatar
moto committed
252
# --------------
253
#
moto's avatar
moto committed
254
255
256

######################################################################
# Multi-channel mask, ref_channel solution
moto's avatar
moto committed
257
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
258
#
moto's avatar
moto committed
259

260
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_multi["ref_channel"], rate=sr)
moto's avatar
moto committed
261
262
263

######################################################################
# Multi-channel mask, stv_evd solution
moto's avatar
moto committed
264
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
265
#
moto's avatar
moto committed
266

267
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_multi["stv_evd"], rate=sr)
moto's avatar
moto committed
268
269
270

######################################################################
# Multi-channel mask, stv_power solution
moto's avatar
moto committed
271
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
272
#
moto's avatar
moto committed
273

274
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_multi["stv_power"], rate=sr)
moto's avatar
moto committed
275
276
277

######################################################################
# Single-channel mask, ref_channel solution
moto's avatar
moto committed
278
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
279
#
moto's avatar
moto committed
280

281
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_single["ref_channel"], rate=sr)
moto's avatar
moto committed
282
283
284

######################################################################
# Single-channel mask, stv_evd solution
moto's avatar
moto committed
285
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
286
#
moto's avatar
moto committed
287

288
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_single["stv_evd"], rate=sr)
moto's avatar
moto committed
289
290
291

######################################################################
# Single-channel mask, stv_power solution
moto's avatar
moto committed
292
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
293
#
moto's avatar
moto committed
294

295
# Play back at the sample rate read from the file instead of a hard-coded value.
ipd.Audio(results_single["stv_power"], rate=sr)