"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "688448db7547be90203440cfd105703d8a853f39"
Commit d01f5891 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Update MVDR beamforming tutorial (#2398)

Summary:
- Use `download_asset` to download audios.
- Replace `MVDR` module with the newly added `SoudenMVDR` and `RTFMVDR` modules.
- Benchmark the performance of `F.rtf_evd` and `F.rtf_power` for RTF computation.
- Visualize the spectrograms and masks.

Pull Request resolved: https://github.com/pytorch/audio/pull/2398

Reviewed By: carolineechen

Differential Revision: D36549402

Pulled By: nateanl

fbshipit-source-id: dfd6754e6c33246e6991ccc51c4603b12502a1b5
parent 19c60a08
""" """
MVDR with torchaudio Speech Enhancement with MVDR Beamforming
==================== ========================================
**Author** `Zhaoheng Ni <zni@fb.com>`__ **Author** `Zhaoheng Ni <zni@fb.com>`__
""" """
###################################################################### ######################################################################
# Overview # 1. Overview
# -------- # -----------
# #
# This is a tutorial on how to apply MVDR beamforming with # This is a tutorial on applying Minimum Variance Distortionless
# :py:func:`torchaudio.transforms.MVDR`. # Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
# #
# Steps # Steps:
#
# - Generate an ideal ratio mask (IRM) by dividing the clean/noise
# magnitude by the mixture magnitude.
# - Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# - Estimate enhanced speech using MVDR modules
# (:py:func:`torchaudio.transforms.SoudenMVDR` and
# :py:func:`torchaudio.transforms.RTFMVDR`).
# - Benchmark the two methods
# (:py:func:`torchaudio.functional.rtf_evd` and
# :py:func:`torchaudio.functional.rtf_power`) for computing the
# relative transfer function (RTF) matrix of the reference microphone.
# #
# - Ideal Ratio Mask (IRM) is generated by dividing the clean/noise
# magnitude by the mixture magnitude. import torch
# - We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``) import torchaudio
# of torchaudio's MVDR module. import torchaudio.functional as F
# - We test the single-channel and multi-channel masks for MVDR beamforming.
# The multi-channel mask is averaged along channel dimension when computing print(torch.__version__)
# the covariance matrices of speech and noise, respectively. print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparation # 2. Preparation
# ----------- # --------------
# #
# First, we import the necessary packages and retrieve the data. # First, we import the necessary packages and retrieve the data.
# #
...@@ -38,258 +51,303 @@ MVDR with torchaudio ...@@ -38,258 +51,303 @@ MVDR with torchaudio
# #
# ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav`` # ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
# #
# which was generated with; # which was generated with:
# #
# - ``SSB07200001.wav`` from `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License v.2.0) # - ``SSB07200001.wav`` from
# - ``noise-sound-bible-0038.wav`` from `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0 International — CC BY 4.0) # noqa: E501 # `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
# v.2.0)
# - ``noise-sound-bible-0038.wav`` from
# `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
# International — CC BY 4.0)
# #
import os import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
import IPython.display as ipd SAMPLE_RATE = 16000
import requests SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
import torch SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
import torchaudio
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__) ######################################################################
print(torchaudio.__version__) # 2.1. Helper functions
print(device) # ~~~~~~~~~~~~~~~~~~~~~
#
filenames = [ def plot_spectrogram(stft, title="Spectrogram", xlim=None):
"mix.wav", magnitude = stft.abs()
"reverb_clean.wav", spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
"clean.wav", figure, axis = plt.subplots(1, 1)
] img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr" figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None):
    """Show a time-frequency mask as a color-mapped image with a colorbar.

    Args:
        mask: Tensor holding the mask values; it is converted with ``.numpy()``
            before plotting (presumably a 2-D, single-channel mask — confirm
            against the caller).
        title: Text used as the figure title.
        xlim: Accepted by the signature but not used by this function.
    """
    # Convert first so a failing conversion does not leave an empty figure behind.
    mask_values = mask.numpy()

    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(mask_values, cmap="viridis", origin="lower", aspect="auto")
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
    """Compute the scale-invariant signal-to-noise ratio (Si-SNR) in dB.

    Both signals are zero-meaned, the reference is rescaled by the projection
    coefficient ``<estimate, reference> / (<reference, reference> + epsilon)``,
    and the ratio between the power of that scaled reference and the power of
    the residual is returned in decibels.

    Args:
        estimate: 2-D tensor of the estimated signal; statistics are taken along
            dim 1. The result is read with ``.item()``, so the input must hold a
            single row (presumably ``(1, time)`` — confirm against the caller).
        reference: Tensor of the reference signal, same shape as ``estimate``.
        epsilon: Small constant added to the reference power in the denominator
            of the scale factor to avoid division by zero.

    Returns:
        The Si-SNR score as a Python float.
    """
    # Remove the DC offset per signal (dim 1), matching the per-row statistics
    # below; for a single-row input this equals the global mean.
    estimate = estimate - estimate.mean(dim=1, keepdim=True)
    reference = reference - reference.mean(dim=1, keepdim=True)

    # Projection of the estimate onto the reference (scale invariance).
    reference_pow = reference.pow(2).mean(dim=1, keepdim=True)
    mix_pow = (estimate * reference).mean(dim=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    target = scale * reference
    error = estimate - target

    target_pow = target.pow(2).mean(dim=1)
    error_pow = error.pow(2).mean(dim=1)

    # Named `score` so the local does not shadow the function itself.
    score = 10 * torch.log10(target_pow) - 10 * torch.log10(error_pow)
    return score.item()
for filename in filenames:
os.makedirs("_assets", exist_ok=True)
if not os.path.exists(filename):
with open(f"_assets/{filename}", "wb") as file:
file.write(requests.get(f"{base_url}/{filename}").content)
###################################################################### ######################################################################
# Generate the Ideal Ratio Mask (IRM) # 3. Generate Ideal Ratio Masks (IRMs)
# ----------------------------------- # ------------------------------------
# #
###################################################################### ######################################################################
# Loading audio data # 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~
# #
mix, sr = torchaudio.load("_assets/mix.wav") waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
reverb_clean, sr2 = torchaudio.load("_assets/reverb_clean.wav") waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
clean, sr3 = torchaudio.load("_assets/clean.wav") assert sr == sr2 == SAMPLE_RATE
assert sr == sr2 # The mixture waveform is a combination of clean and noise waveforms
waveform_mix = waveform_clean + waveform_noise
noise = mix - reverb_clean
###################################################################### ######################################################################
# # Note: To improve computational robustness, it is recommended to represent
# .. note:: # the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
# The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT.
# We need to convert the dtype of the waveforms to ``torch.double``
# #
mix = mix.to(torch.double) waveform_mix = waveform_mix.to(torch.double)
noise = noise.to(torch.double) waveform_clean = waveform_clean.to(torch.double)
clean = clean.to(torch.double) waveform_noise = waveform_noise.to(torch.double)
reverb_clean = reverb_clean.to(torch.double)
###################################################################### ######################################################################
# Compute STFT # 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
N_FFT = 1024
N_HOP = 256
stft = torchaudio.transforms.Spectrogram( stft = torchaudio.transforms.Spectrogram(
n_fft=1024, n_fft=N_FFT,
hop_length=256, hop_length=N_HOP,
power=None, power=None,
) )
istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256) istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)
stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)
spec_mix = stft(mix)
spec_clean = stft(clean)
spec_reverb_clean = stft(reverb_clean)
spec_noise = stft(noise)
###################################################################### ######################################################################
# Generate the Ideal Ratio Mask (IRM) # 3.2.1. Visualize mixture speech
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# .. note::
# We found using the mask directly performs better than using the
# square root of it. This is slightly different from the definition of IRM.
# #
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
def get_irms(spec_clean, spec_noise): Audio(waveform_mix[0], rate=SAMPLE_RATE)
mag_clean = spec_clean.abs() ** 2
mag_noise = spec_noise.abs() ** 2
irm_speech = mag_clean / (mag_clean + mag_noise)
irm_noise = mag_noise / (mag_clean + mag_noise)
return irm_speech, irm_noise
###################################################################### ######################################################################
# .. note:: # 3.2.2. Visualize clean speech
# We use reverberant clean speech as the target here, # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# you can also set it to dry clean speech. #
irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise) plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)
######################################################################
# Apply MVDR
# ----------
#
###################################################################### ######################################################################
# Apply MVDR beamforming by using multi-channel masks # 3.2.3. Visualize noise
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ^^^^^^^^^^^^^^^^^^^^^^
# #
results_multi = {} plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
for solution in ["ref_channel", "stv_evd", "stv_power"]: Audio(waveform_noise[0], rate=SAMPLE_RATE)
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)
stft_est = mvdr(spec_mix, irm_speech, irm_noise)
est = istft(stft_est, length=mix.shape[-1])
results_multi[solution] = est
###################################################################### ######################################################################
# Apply MVDR beamforming by using single-channel masks # 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
# #
# We use the 1st channel as an example. # You can also apply an end-to-end neural network which estimates both the reference channel and
# The channel selection may depend on the design of the microphone array # the PSD matrices, then obtains the enhanced STFT coefficients by the MVDR module.
REFERENCE_CHANNEL = 0
results_single = {}
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
est = istft(stft_est, length=mix.shape[-1])
results_single[solution] = est
###################################################################### ######################################################################
# Compute Si-SDR scores # 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~
# #
def si_sdr(estimate, reference, epsilon=1e-8): def get_irms(stft_clean, stft_noise):
estimate = estimate - estimate.mean() mag_clean = stft_clean.abs() ** 2
reference = reference - reference.mean() mag_noise = stft_noise.abs() ** 2
reference_pow = reference.pow(2).mean(axis=1, keepdim=True) irm_speech = mag_clean / (mag_clean + mag_noise)
mix_pow = (estimate * reference).mean(axis=1, keepdim=True) irm_noise = mag_noise / (mag_clean + mag_noise)
scale = mix_pow / (reference_pow + epsilon) return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]
reference = scale * reference
error = estimate - reference
reference_pow = reference.pow(2)
error_pow = error.pow(2)
reference_pow = reference_pow.mean(axis=1)
error_pow = error_pow.mean(axis=1)
sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow) irm_speech, irm_noise = get_irms(stft_clean, stft_noise)
return sisdr.item()
###################################################################### ######################################################################
# Results # 3.4.1. Visualize IRM of target speech
# ------- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #
###################################################################### plot_mask(irm_speech, "IRM of the Target Speech")
# Single-channel mask results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
for solution in results_single:
print(solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1]))
###################################################################### ######################################################################
# Multi-channel mask results # 3.4.2. Visualize IRM of noise
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #
for solution in results_multi: plot_mask(irm_noise, "IRM of the Noise")
print(solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1]))
###################################################################### ######################################################################
# Original audio # 4. Compute PSD matrices
# -------------- # -----------------------
# #
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
###################################################################### # the multi-channel complex-valued STFT coefficients of the mixture speech
# Mixture speech # and the time-frequency mask.
# ~~~~~~~~~~~~~~
# #
# The shape of the PSD matrix is `(..., freq, channel, channel)`.
ipd.Audio(mix[0], rate=16000) psd_transform = torchaudio.transforms.PSD()
###################################################################### psd_speech = psd_transform(stft_mix, irm_speech)
# Noise psd_noise = psd_transform(stft_mix, irm_noise)
# ~~~~~
#
ipd.Audio(noise[0], rate=16000)
###################################################################### ######################################################################
# Clean speech # 5. Beamforming using SoudenMVDR
# ~~~~~~~~~~~~ # -------------------------------
# #
ipd.Audio(clean[0], rate=16000)
###################################################################### ######################################################################
# Enhanced audio # 5.1. Apply beamforming
# -------------- # ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel inputs.
# #
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.
mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
###################################################################### ######################################################################
# Multi-channel mask, ref_channel solution # 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_multi["ref_channel"], rate=16000) plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_souden, waveform_clean[0:1])}")
Audio(waveform_souden, rate=SAMPLE_RATE)
###################################################################### ######################################################################
# Multi-channel mask, stv_evd solution # 6. Beamforming using RTFMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ----------------------------
# #
ipd.Audio(results_multi["stv_evd"], rate=16000)
###################################################################### ######################################################################
# Multi-channel mask, stv_power solution # 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of a
# target speech:
#
# - :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
# decomposition to the PSD matrix of target speech to get the RTF matrix.
# #
# - :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
# method. You can specify the number of iterations with argument ``n_iter``.
#
rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
ipd.Audio(results_multi["stv_power"], rate=16000)
###################################################################### ######################################################################
# Single-channel mask, ref_channel solution # 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel inputs.
# #
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.
mvdr_transform = torchaudio.transforms.RTFMVDR()
# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])
# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])
ipd.Audio(results_single["ref_channel"], rate=16000)
###################################################################### ######################################################################
# Single-channel mask, stv_evd solution # 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_single["stv_evd"], rate=16000) plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_evd, waveform_clean[0:1])}")
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)
###################################################################### ######################################################################
# Single-channel mask, stv_power solution # 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
ipd.Audio(results_single["stv_power"], rate=16000) plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_power, waveform_clean[0:1])}")
Audio(waveform_rtf_power, rate=SAMPLE_RATE)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment