"torchvision/vscode:/vscode.git/clone" did not exist on "138b5c259578e2d4e2e2c4f538beac7a1652a551"
Commit d01f5891 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Update MVDR beamforming tutorial (#2398)

Summary:
- Use `download_asset` to download the audio assets.
- Replace the `MVDR` module with the newly added `SoudenMVDR` and `RTFMVDR` modules.
- Benchmark the performance of `F.rtf_evd` and `F.rtf_power` for RTF computation.
- Visualize the spectrograms and masks.

Pull Request resolved: https://github.com/pytorch/audio/pull/2398

Reviewed By: carolineechen

Differential Revision: D36549402

Pulled By: nateanl

fbshipit-source-id: dfd6754e6c33246e6991ccc51c4603b12502a1b5
"""
Speech Enhancement with MVDR Beamforming
========================================
**Author** `Zhaoheng Ni <zni@fb.com>`__
"""
######################################################################
# 1. Overview
# -----------
#
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
#
# Steps:
#
# - Generate an ideal ratio mask (IRM) by dividing the clean/noise
#   magnitude by the mixture magnitude.
# - Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# - Estimate enhanced speech using MVDR modules
#   (:py:func:`torchaudio.transforms.SoudenMVDR` and
#   :py:func:`torchaudio.transforms.RTFMVDR`).
# - Benchmark the two methods
#   (:py:func:`torchaudio.functional.rtf_evd` and
#   :py:func:`torchaudio.functional.rtf_power`) for computing the
#   relative transfer function (RTF) matrix of the reference microphone.
#
import torch
import torchaudio
import torchaudio.functional as F
print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# 2. Preparation
# --------------
#
# First, we import the necessary packages and retrieve the data.
#
# The original filename of the multi-channel audio sample is
#
# ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
#
# - ``SSB07200001.wav`` from
#   `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
#   v.2.0)
# - ``noise-sound-bible-0038.wav`` from
#   `MUSAN <http://www.openslr.org/17/>`__ (Attribution 4.0
#   International — CC BY 4.0)
#
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
SAMPLE_RATE = 16000
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
######################################################################
# 2.1. Helper functions
# ~~~~~~~~~~~~~~~~~~~~~
#


def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    # Convert the complex-valued STFT into a log-magnitude spectrogram (dB).
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def plot_mask(mask, title="Mask", xlim=None):
    # Plot a time-frequency mask whose values lie in [0, 1].
    mask = mask.numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()


def si_snr(estimate, reference, epsilon=1e-8):
    # Scale-invariant signal-to-noise ratio (Si-SNR) in dB.
    # Both inputs are expected to have shape (1, time).
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return si_snr.item()
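

# A quick sanity check of the Si-SNR helper on synthetic data (an illustrative
# addition, not part of the original tutorial; the underscore-prefixed
# variables below are hypothetical): an estimate closer to the reference
# should score higher than a noisier one.
_reference = torch.rand(1, SAMPLE_RATE, dtype=torch.double)
_good_estimate = _reference + 0.01 * torch.randn(1, SAMPLE_RATE, dtype=torch.double)
_poor_estimate = _reference + 0.5 * torch.randn(1, SAMPLE_RATE, dtype=torch.double)
print(si_snr(_good_estimate, _reference), si_snr(_poor_estimate, _reference))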
######################################################################
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
#
######################################################################
# 3.1. Load audio data
# ~~~~~~~~~~~~~~~~~~~~
#
waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms
waveform_mix = waveform_clean + waveform_noise
######################################################################
#
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
#
waveform_mix = waveform_mix.to(torch.double)
waveform_clean = waveform_clean.to(torch.double)
waveform_noise = waveform_noise.to(torch.double)
######################################################################
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
N_FFT = 1024
N_HOP = 256

stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)
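
# Illustrative check (an addition to the tutorial): ``InverseSpectrogram``
# configured with the same ``n_fft`` and ``hop_length`` approximately inverts
# the ``Spectrogram`` transform, so the reconstruction error of the mixture
# waveform should be close to machine precision.
_roundtrip = istft(stft_mix, length=waveform_mix.shape[-1])
print((_roundtrip - waveform_mix).abs().max())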
######################################################################
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
Audio(waveform_mix[0], rate=SAMPLE_RATE)
######################################################################
# 3.2.2. Visualize clean speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)
######################################################################
# 3.2.3. Visualize noise
# ^^^^^^^^^^^^^^^^^^^^^^
#
plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)
######################################################################
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
#
# You can also apply an end-to-end neural network that estimates both the reference channel and
# the PSD matrices, and then obtains the enhanced STFT coefficients with the MVDR module.
REFERENCE_CHANNEL = 0
######################################################################
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
#


def get_irms(stft_clean, stft_noise):
    # The ideal ratio mask (IRM) assigns each time-frequency bin the ratio
    # of the target (or noise) power to the total power.
    mag_clean = stft_clean.abs() ** 2
    mag_noise = stft_noise.abs() ** 2
    irm_speech = mag_clean / (mag_clean + mag_noise)
    irm_noise = mag_noise / (mag_clean + mag_noise)
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]

irm_speech, irm_noise = get_irms(stft_clean, stft_noise)
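
# By construction, the speech and noise masks are complementary: they sum to
# one at every time-frequency bin. A quick illustrative check (an addition,
# not part of the original tutorial):
print(torch.allclose(irm_speech + irm_noise, torch.ones_like(irm_speech)))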
######################################################################
# 3.4.1. Visualize IRM of target speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
plot_mask(irm_speech, "IRM of the Target Speech")
######################################################################
# 3.4.2. Visualize IRM of noise
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
plot_mask(irm_noise, "IRM of the Noise")
######################################################################
# 4. Compute PSD matrices
# -----------------------
#
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued STFT coefficients of the mixture speech
# and the time-frequency mask.
#
# The shape of the PSD matrix is `(..., freq, channel, channel)`.
psd_transform = torchaudio.transforms.PSD()
psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)
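
# Illustrative shape check (an addition to the tutorial): given STFT
# coefficients of shape (channel, freq, time) and a mask of shape
# (freq, time), the PSD transform returns a matrix of shape
# (freq, channel, channel).
print(stft_mix.shape, irm_speech.shape, psd_speech.shape)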
######################################################################
# 5. Beamforming using SoudenMVDR
# -------------------------------
#
######################################################################
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, the PSD matrices of
# the target speech and noise, and the reference channel as inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.
mvdr_transform = torchaudio.transforms.SoudenMVDR()
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
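
# The beamformer collapses the channel dimension, so the enhanced STFT is
# single-channel with shape (freq, time). An illustrative check (an addition,
# not part of the original tutorial):
print(stft_mix.shape, "->", stft_souden.shape)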
######################################################################
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_multi["ref_channel"], rate=16000)
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_souden, waveform_clean[0:1])}")
Audio(waveform_souden, rate=SAMPLE_RATE)
######################################################################
# 6. Beamforming using RTFMVDR
# ----------------------------
#
ipd.Audio(results_multi["stv_evd"], rate=16000)
######################################################################
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# TorchAudio offers two methods for computing the RTF matrix of the
# target speech:
#
# - :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
#   decomposition to the PSD matrix of the target speech to get the RTF matrix.
#
# - :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
#   method. You can specify the number of iterations with the argument ``n_iter``.
#
rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
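
# As an illustrative aside (an addition, not part of the original tutorial):
# since ``rtf_power`` uses the power iteration method, you can trade extra
# computation for a potentially more accurate RTF estimate by increasing
# ``n_iter``. Both results share the RTF shape of (freq, channel).
rtf_power_more_iters = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL, n_iter=10)
print(rtf_power.shape, rtf_power_more_iters.shape)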
ipd.Audio(results_multi["stv_power"], rate=16000)
######################################################################
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complex-valued STFT coefficients of the mixture speech, the RTF matrix of the target speech,
# the PSD matrix of the noise, and the reference channel as inputs.
#
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.
mvdr_transform = torchaudio.transforms.RTFMVDR()
# compute the enhanced speech based on F.rtf_evd
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])
# compute the enhanced speech based on F.rtf_power
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])
ipd.Audio(results_single["ref_channel"], rate=16000)
######################################################################
# 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_single["stv_evd"], rate=16000)
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_evd, waveform_clean[0:1])}")
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)
######################################################################
# 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
ipd.Audio(results_single["stv_power"], rate=16000)
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_power, waveform_clean[0:1])}")
Audio(waveform_rtf_power, rate=SAMPLE_RATE)