Commit d6267031 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Improve speech enhancement tutorial (#2527)

Summary:
- The "speech + noise" mixture still has a high SNR, which can't show the effectiveness of MVDR beamforming. To make the task more challenging, amplify the noise waveform to reduce the SNR of mixture speech.
- Show the Si-SNR score of mixture speech when visualizing the mixture spectrogram.
- Fix the figure in `rtf_power` subsection.
    - The description of enhanced spectrogram by `rtf_power` is wrong. Correct it to `rtf_power`.
- Print PESQ, STOI, and SDR metric scores.

Pull Request resolved: https://github.com/pytorch/audio/pull/2527

Reviewed By: mthrok

Differential Revision: D38190218

Pulled By: nateanl

fbshipit-source-id: 39562850a67f58a16e0a2866ed95f78c3f4dc7de
parent 0fde7c57
...@@ -833,7 +833,8 @@ jobs: ...@@ -833,7 +833,8 @@ jobs:
conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge
fi fi
conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio
conda install -y pandoc 'ffmpeg<5' # gxx_linux-64 is for installing pesq library that depends on cython
conda install -y pandoc 'ffmpeg<5' gxx_linux-64
apt update -qq && apt-get -qq install -y git make apt update -qq && apt-get -qq install -y git make
pip install --progress-bar off -r docs/requirements.txt -r docs/requirements-tutorials.txt pip install --progress-bar off -r docs/requirements.txt -r docs/requirements-tutorials.txt
- run: - run:
......
...@@ -833,7 +833,8 @@ jobs: ...@@ -833,7 +833,8 @@ jobs:
conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge
fi fi
conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio
conda install -y pandoc 'ffmpeg<5' # gxx_linux-64 is for installing pesq library that depends on cython
conda install -y pandoc 'ffmpeg<5' gxx_linux-64
apt update -qq && apt-get -qq install -y git make apt update -qq && apt-get -qq install -y git make
pip install --progress-bar off -r docs/requirements.txt -r docs/requirements-tutorials.txt pip install --progress-bar off -r docs/requirements.txt -r docs/requirements-tutorials.txt
- run: - run:
......
IPython IPython
deep-phonemizer deep-phonemizer
boto3 boto3
cython
pandas pandas
librosa librosa
sentencepiece sentencepiece
nbsphinx nbsphinx
pandoc pandoc
mir_eval mir_eval
pesq
pystoi
...@@ -41,7 +41,33 @@ print(torchaudio.__version__) ...@@ -41,7 +41,33 @@ print(torchaudio.__version__)
# 2. Preparation # 2. Preparation
# -------------- # --------------
# #
# First, we import the necessary packages and retrieve the data.
######################################################################
# 2.1. Import the packages
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# First, we install and import the necessary packages.
#
# ``mir_eval``, ``pesq``, and ``pystoi`` packages are required for
# evaluating the speech enhancement performance.
#
# When running this example in a notebook, install the following packages.
# !pip3 install mir_eval
# !pip3 install pesq
# !pip3 install pystoi
from pesq import pesq
from pystoi import stoi
import mir_eval
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
######################################################################
# 2.2. Download audio data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# The multi-channel audio example is selected from # The multi-channel audio example is selected from
# `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__ # `ConferencingSpeech <https://github.com/ConferencingSpeech/ConferencingSpeech2021>`__
...@@ -61,17 +87,13 @@ print(torchaudio.__version__) ...@@ -61,17 +87,13 @@ print(torchaudio.__version__)
# International — CC BY 4.0) # International — CC BY 4.0)
# #
import matplotlib.pyplot as plt
from IPython.display import Audio
from torchaudio.utils import download_asset
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav") SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav") SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
###################################################################### ######################################################################
# 2.1. Helper functions # 2.3. Helper functions
# ~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~
# #
...@@ -115,6 +137,30 @@ def si_snr(estimate, reference, epsilon=1e-8): ...@@ -115,6 +137,30 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item() return si_snr.item()
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech with noise at the requested signal-to-noise ratio.

    The noise waveform is rescaled **in place** so that the power ratio of
    ``waveform_clean`` to ``waveform_noise`` equals ``target_snr`` dB; later
    tutorial steps reuse the amplified noise tensor, so the in-place update
    is intentional.

    Args:
        waveform_clean (Tensor): clean speech waveform.
        waveform_noise (Tensor): noise waveform; modified in place.
        target_snr (float): desired SNR of the mixture, in dB.

    Returns:
        Tensor: the mixture ``waveform_clean + waveform_noise`` (with the
        noise already rescaled).
    """
    clean_power = waveform_clean.pow(2).mean()
    noise_power = waveform_noise.pow(2).mean()
    # SNR of the raw pair, in dB.
    snr_now = 10 * torch.log10(clean_power / noise_power)
    # Scale noise amplitude by the dB gap (amplitude gain = 10^(dB/20)).
    waveform_noise *= 10 ** ((snr_now - target_snr) / 20)
    return waveform_clean + waveform_noise
def evaluate(estimate, reference):
    """Print speech-enhancement metrics of ``estimate`` against ``reference``.

    Reports SDR (via ``mir_eval``), Si-SNR (via the tutorial's ``si_snr``
    helper), wide-band PESQ, and STOI. Both inputs are single-channel
    ``(1, time)`` tensors; only channel 0 is used for PESQ/STOI.
    """
    si_snr_score = si_snr(estimate, reference)
    # bss_eval_sources returns (sdr, sir, sar, perm); only SDR is reported.
    sdr = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)[0]
    est_np = estimate[0].numpy()
    ref_np = reference[0].numpy()
    pesq_score = pesq(SAMPLE_RATE, est_np, ref_np, "wb")
    stoi_score = stoi(ref_np, est_np, SAMPLE_RATE, extended=False)
    print(f"SDR score: {sdr[0]}")
    print(f"Si-SNR score: {si_snr_score}")
    print(f"PESQ score: {pesq_score}")
    print(f"STOI score: {stoi_score}")
###################################################################### ######################################################################
# 3. Generate Ideal Ratio Masks (IRMs) # 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------ # ------------------------------------
...@@ -129,8 +175,9 @@ def si_snr(estimate, reference, epsilon=1e-8): ...@@ -129,8 +175,9 @@ def si_snr(estimate, reference, epsilon=1e-8):
waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN) waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE) waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE assert sr == sr2 == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms # The mixture waveform is a combination of clean and noise waveforms with a desired SNR.
waveform_mix = waveform_clean + waveform_noise target_snr = 3
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)
###################################################################### ######################################################################
...@@ -166,8 +213,17 @@ stft_noise = stft(waveform_noise) ...@@ -166,8 +213,17 @@ stft_noise = stft(waveform_noise)
# 3.2.1. Visualize mixture speech # 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# #
# We evaluate the quality of the mixture speech or the enhanced speech
# using the following three metrics:
#
# - signal-to-distortion ratio (SDR)
# - scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)
# - Perceptual Evaluation of Speech Quality (PESQ)
# We also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility
# (STOI) metric.
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)") plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
evaluate(waveform_mix[0:1], waveform_clean[0:1])
Audio(waveform_mix[0], rate=SAMPLE_RATE) Audio(waveform_mix[0], rate=SAMPLE_RATE)
...@@ -280,7 +336,7 @@ waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1]) ...@@ -280,7 +336,7 @@ waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)") plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1) waveform_souden = waveform_souden.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_souden, waveform_clean[0:1])}") evaluate(waveform_souden, waveform_clean[0:1])
Audio(waveform_souden, rate=SAMPLE_RATE) Audio(waveform_souden, rate=SAMPLE_RATE)
...@@ -338,7 +394,7 @@ waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1]) ...@@ -338,7 +394,7 @@ waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)") plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1) waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_evd, waveform_clean[0:1])}") evaluate(waveform_rtf_evd, waveform_clean[0:1])
Audio(waveform_rtf_evd, rate=SAMPLE_RATE) Audio(waveform_rtf_evd, rate=SAMPLE_RATE)
...@@ -347,7 +403,7 @@ Audio(waveform_rtf_evd, rate=SAMPLE_RATE) ...@@ -347,7 +403,7 @@ Audio(waveform_rtf_evd, rate=SAMPLE_RATE)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)") plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
waveform_rtf_power = waveform_rtf_power.reshape(1, -1) waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_power, waveform_clean[0:1])}") evaluate(waveform_rtf_power, waveform_clean[0:1])
Audio(waveform_rtf_power, rate=SAMPLE_RATE) Audio(waveform_rtf_power, rate=SAMPLE_RATE)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment