"""
MVDR with torchaudio
====================

**Author** `Zhaoheng Ni <zni@fb.com>`__

"""

######################################################################
# Overview
# --------
#
# This is a tutorial on how to apply MVDR beamforming by using `torchaudio <https://github.com/pytorch/audio>`__.
#
# Steps
#
# - The Ideal Ratio Mask (IRM) is generated by dividing the clean (or noise)
#   power spectrum by the sum of the clean and noise power spectra.
# - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``)
#   of torchaudio's MVDR module.
# - We test the single-channel and multi-channel masks for MVDR beamforming.
#   The multi-channel mask is averaged along the channel dimension when computing
#   the covariance matrices of speech and noise, respectively.


######################################################################
# Preparation
# -----------
#
# First, we import the necessary packages and retrieve the data.
#
# The multi-channel audio example is selected from the
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
# dataset.
#
# The original filename is
#
#    ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0)
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0)  # noqa: E501
#

import os

import IPython.display as ipd
import requests
import torch
import torchaudio

# Fix the RNG so repeated runs of the tutorial give the same numbers.
torch.random.manual_seed(0)

# NOTE(review): ``device`` is only reported below; no tensor in this script is
# moved to it (there is no ``.to(device)`` call), so it is informational only.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Report library versions and the selected device for reproducibility.
for _info in (torch.__version__, torchaudio.__version__, device):
    print(_info)

filenames = [
    "mix.wav",
    "reverb_clean.wav",
    "clean.wav",
]
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr"

# Create the asset directory once, outside the download loop.
os.makedirs("_assets", exist_ok=True)
for filename in filenames:
    # Check the path the file is actually written to. Checking the bare
    # ``filename`` (in the current directory) never matched, so every run
    # re-downloaded all assets.
    local_path = os.path.join("_assets", filename)
    if not os.path.exists(local_path):
        # Fetch before opening the output file so that a failed request does
        # not leave an empty or partial file behind.
        response = requests.get(f"{base_url}/{filename}", timeout=60)
        # Fail loudly instead of saving an HTTP error page as a .wav file.
        response.raise_for_status()
        with open(local_path, "wb") as file:
            file.write(response.content)
74

######################################################################
# Generate the Ideal Ratio Mask (IRM)
# -----------------------------------
#

######################################################################
# Loading audio data
# ~~~~~~~~~~~~~~~~~~
#

mix, sr = torchaudio.load("_assets/mix.wav")
reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav")
clean, sr3 = torchaudio.load("_assets/clean.wav")

# All three signals go through the same STFT and are played back at one rate,
# so they must share a sample rate. ``sr3`` was previously loaded but never
# checked, and a bare ``assert`` is stripped when Python runs with ``-O``.
if not sr == sr2 == sr3:
    raise ValueError(f"Sample rates differ: mix={sr}, reverb_clean={sr2}, clean={sr3}")

# The noise component is what remains of the mixture after removing the
# reverberant clean speech.
noise = mix - reverb_clean

######################################################################
#
# .. note::
#    The MVDR module requires the ``torch.cdouble`` dtype for the noisy STFT,
#    so we convert the waveforms to ``torch.double``.
#

# MVDR needs a complex128 (``torch.cdouble``) STFT, so promote every
# waveform to float64 before transforming.
mix, noise, clean, reverb_clean = (
    waveform.to(torch.double) for waveform in (mix, noise, clean, reverb_clean)
)

######################################################################
# Compute STFT
# ~~~~~~~~~~~~
#

# Forward and inverse transforms share the same FFT size and hop length so the
# inverse STFT matches the framing of the forward one.
_stft_kwargs = {"n_fft": 1024, "hop_length": 256}
# ``power=None`` keeps the complex spectrogram that MVDR operates on.
stft = torchaudio.transforms.Spectrogram(**_stft_kwargs, power=None)
istft = torchaudio.transforms.InverseSpectrogram(**_stft_kwargs)

spec_mix, spec_clean, spec_reverb_clean, spec_noise = (
    stft(signal) for signal in (mix, clean, reverb_clean, noise)
)

######################################################################
# Generate the Ideal Ratio Mask (IRM)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# .. note::
#    We use the ratio of power spectra directly instead of its square root.
#    This differs slightly from the standard definition of the IRM, but we
#    found that it performs better.
#


def get_irms(spec_clean, spec_noise):
    """Compute ideal ratio masks (IRMs) for speech and noise.

    Each mask is the ratio of one source's power spectrum to the sum of the
    speech and noise power spectra. The two masks therefore sum to 1 (up to
    rounding) wherever that sum is non-zero.

    Args:
        spec_clean (torch.Tensor): STFT (complex or real) of the target speech.
        spec_noise (torch.Tensor): STFT of the noise, broadcastable with
            ``spec_clean``.

    Returns:
        (torch.Tensor, torch.Tensor): ``(irm_speech, irm_noise)``, real-valued
        masks with values in ``[0, 1]``. Bins where both inputs are zero get
        0 in both masks instead of NaN.
    """
    power_clean = spec_clean.abs() ** 2
    power_noise = spec_noise.abs() ** 2
    total = power_clean + power_noise
    # Guard against 0 / 0 = NaN in silent bins. Replacing a zero denominator
    # with 1 gives mask value 0 there and leaves every other bin unchanged.
    total = torch.where(total == 0, torch.ones_like(total), total)
    irm_speech = power_clean / total
    irm_noise = power_noise / total

    return irm_speech, irm_noise


######################################################################
# .. note::
#    We use the reverberant clean speech as the target here;
#    you can also set it to the dry clean speech.

# Masks are computed against the reverberant (not dry) clean speech.
irm_speech, irm_noise = get_irms(spec_clean=spec_reverb_clean, spec_noise=spec_noise)

######################################################################
# Apply MVDR
# ----------
#

######################################################################
# Apply MVDR beamforming by using multi-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Beamform the mixture with every MVDR solution, passing the full
# multi-channel speech/noise masks (``multi_mask=True``). ``length`` makes each
# inverse STFT produce as many samples as the input mixture.
results_multi = {}
for solution in ("ref_channel", "stv_evd", "stv_power"):
    beamformer = torchaudio.transforms.MVDR(
        ref_channel=0,
        solution=solution,
        multi_mask=True,
    )
    enhanced_spec = beamformer(spec_mix, irm_speech, irm_noise)
    results_multi[solution] = istft(enhanced_spec, length=mix.shape[-1])

######################################################################
# Apply MVDR beamforming by using single-channel masks
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We use the mask of the first channel as an example.
# Which channel to use may depend on the design of the microphone array.

# Same as above, but only the first channel's masks (index 0) are given to
# MVDR (``multi_mask=False``).
results_single = {}
for solution in ("ref_channel", "stv_evd", "stv_power"):
    beamformer = torchaudio.transforms.MVDR(
        ref_channel=0,
        solution=solution,
        multi_mask=False,
    )
    enhanced_spec = beamformer(spec_mix, irm_speech[0], irm_noise[0])
    results_single[solution] = istft(enhanced_spec, length=mix.shape[-1])

######################################################################
# Compute SI-SDR scores
# ~~~~~~~~~~~~~~~~~~~~~
#
# The scale-invariant signal-to-distortion ratio (SI-SDR) measures, in dB,
# how closely an estimate matches a reference signal up to a gain factor;
# higher is better.


def si_sdr(estimate, reference, epsilon=1e-8):
    """Return the scale-invariant signal-to-distortion ratio (SI-SDR) in dB.

    Both signals are mean-centred. ``reference`` is then scaled by the
    least-squares gain that best matches ``estimate``, and the power of that
    scaled target is compared to the power of the remaining error.

    NOTE: the mean is removed over the whole tensor, not per row. This is only
    equivalent to per-row centring for single-row inputs, which is all this
    function supports anyway because of the final ``.item()``.

    Args:
        estimate (torch.Tensor): Estimated signal, laid out ``(row, time)`` with
            a single row.
        reference (torch.Tensor): Reference signal with the same layout.
        epsilon (float): Added to the reference power when computing the gain,
            to avoid division by zero.

    Returns:
        float: SI-SDR in dB.
    """
    # Remove DC offsets (global mean over all elements).
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()

    # Least-squares gain mapping ``reference`` onto ``estimate``.
    cross_power = (estimate * reference).mean(axis=1, keepdim=True)
    ref_power = reference.pow(2).mean(axis=1, keepdim=True)
    target = cross_power / (ref_power + epsilon) * reference
    residual = estimate - target

    target_db = 10 * torch.log10(target.pow(2).mean(axis=1))
    residual_db = 10 * torch.log10(residual.pow(2).mean(axis=1))
    return (target_db - residual_db).item()



######################################################################
# Results
# -------
#

######################################################################
# Single-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Score each estimate against channel 0 of the reverberant clean speech,
# the same channel used as MVDR's reference channel.
for solution, enhanced in results_single.items():
    print(f"{solution}: ", si_sdr(enhanced[None, ...], reverb_clean[0:1]))

######################################################################
# Multi-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Score each estimate against channel 0 of the reverberant clean speech,
# the same channel used as MVDR's reference channel.
for solution, enhanced in results_multi.items():
    print(f"{solution}: ", si_sdr(enhanced[None, ...], reverb_clean[0:1]))

######################################################################
# Original audio
# --------------
#

######################################################################
# Mixture speech
# ~~~~~~~~~~~~~~
#

# Play channel 0 at the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(mix[0], rate=sr)

######################################################################
# Noise
# ~~~~~
#

# Play channel 0 at the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(noise[0], rate=sr)

######################################################################
# Clean speech
# ~~~~~~~~~~~~
#

# Play channel 0 at the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(clean[0], rate=sr)

######################################################################
# Enhanced audio
# --------------
#

######################################################################
# Multi-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_multi["ref_channel"], rate=sr)

######################################################################
# Multi-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_multi["stv_evd"], rate=sr)

######################################################################
# Multi-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_multi["stv_power"], rate=sr)

######################################################################
# Single-channel mask, ref_channel solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_single["ref_channel"], rate=sr)

######################################################################
# Single-channel mask, stv_evd solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_single["stv_evd"], rate=sr)

######################################################################
# Single-channel mask, stv_power solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

# Use the loaded sample rate instead of a hard-coded 16 kHz.
ipd.Audio(results_single["stv_power"], rate=sr)