{ "nbformat": 4, "nbformat_minor": 2, "metadata": { "colab": { "name": "Copy of Copy of torchaudio_MVDR_tutorial.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.6 64-bit ('dev': conda)" }, "language_info": { "name": "python", "version": "3.9.6", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "interpreter": { "hash": "6a702c257b9a40163843ba760790c17a6ddd2abeef8febce55475eea4b92c28c" } }, "cells": [ { "cell_type": "markdown", "source": [ "\"Open" ], "metadata": { "id": "xheYDPUcYGbp" } }, { "cell_type": "markdown", "source": [ "This is a tutorial on how to apply MVDR beamforming with [torchaudio](https://github.com/pytorch/audio)\n", "-----------\n", "\n", "The multi-channel audio example is taken from the [ConferencingSpeech](https://github.com/ConferencingSpeech/ConferencingSpeech2021) dataset.\n", "\n", "```\n", "original filename: SSB07200001\\#noise-sound-bible-0038\\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\\#15217\\#25.16333303751458\\#0.2101221178590021.wav\n", "```\n", "\n", "Note:\n", "- You need the nightly build of torchaudio to use the ``MVDR`` and ``InverseSpectrogram`` modules.\n", "\n", "\n", "Steps\n", "\n", "- The Ideal Ratio Mask (IRM) of speech (or noise) is generated by dividing the clean speech (or noise) power spectrum by the sum of the clean speech and noise power spectra.\n", "- We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``) of torchaudio's MVDR module.\n", "- We test both single-channel and multi-channel masks for MVDR beamforming. With multi-channel masks, the speech and noise masks are averaged along the channel dimension before computing the speech and noise covariance matrices, respectively."
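, "\n",
"\n",
"For reference, the textbook MVDR equations behind these options are summarized here. This is a sketch of the standard formulation, not a line-by-line description of torchaudio's implementation, which may add details such as regularization or scaling. With $\\mathbf{y}(t,f)$ the multi-channel mixture STFT vector and $m_S(t,f)$, $m_N(t,f)$ the speech and noise masks (already averaged over channels when multi-channel masks are used), the speech and noise covariance matrices are\n",
"\n",
"$$\\boldsymbol{\\Phi}_{SS}(f) = \\frac{\\sum_t m_S(t,f)\\,\\mathbf{y}(t,f)\\,\\mathbf{y}(t,f)^{H}}{\\sum_t m_S(t,f)}, \\qquad \\boldsymbol{\\Phi}_{NN}(f) = \\frac{\\sum_t m_N(t,f)\\,\\mathbf{y}(t,f)\\,\\mathbf{y}(t,f)^{H}}{\\sum_t m_N(t,f)}.$$\n",
"\n",
"The ``ref_channel`` solution computes the beamforming weights directly from both matrices, with $\\mathbf{u}$ a one-hot vector selecting the reference channel:\n",
"\n",
"$$\\mathbf{w}(f) = \\frac{\\boldsymbol{\\Phi}_{NN}^{-1}(f)\\,\\boldsymbol{\\Phi}_{SS}(f)}{\\operatorname{tr}\\left(\\boldsymbol{\\Phi}_{NN}^{-1}(f)\\,\\boldsymbol{\\Phi}_{SS}(f)\\right)}\\,\\mathbf{u}.$$\n",
"\n",
"The ``stv_evd`` and ``stv_power`` solutions instead first estimate a steering vector $\\mathbf{d}(f)$ from the covariance matrices, using eigenvalue decomposition or power iteration respectively, and then apply the classic MVDR solution\n",
"\n",
"$$\\mathbf{w}(f) = \\frac{\\boldsymbol{\\Phi}_{NN}^{-1}(f)\\,\\mathbf{d}(f)}{\\mathbf{d}(f)^{H}\\,\\boldsymbol{\\Phi}_{NN}^{-1}(f)\\,\\mathbf{d}(f)}.$$\n",
"\n",
"In every case the enhanced single-channel STFT is $\\hat{s}(t,f) = \\mathbf{w}(f)^{H}\\,\\mathbf{y}(t,f)$."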
], "metadata": { "id": "L6R0MXe5Wr19" } }, { "cell_type": "code", "execution_count": null, "source": [ "!pip install --pre torchaudio -f https://download.pytorch.org/whl/nightly/torch_nightly.html --force" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "juO6PE9XLctD", "outputId": "8777ba14-da99-4c18-d80f-b070ad9861af" } }, { "cell_type": "code", "execution_count": null, "source": [ "import torch\n", "import torchaudio\n", "import IPython.display as ipd" ], "outputs": [], "metadata": { "id": "T4u4unhFMMBG" } }, { "cell_type": "markdown", "source": [ "### Load the mixture, reverberant clean speech, and dry clean speech" ], "metadata": { "id": "bDILVXkeg2s3" } }, { "cell_type": "code", "execution_count": null, "source": [ "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/mix.wav\n", "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/reverb_clean.wav\n", "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/clean.wav" ], "outputs": [], "metadata": { "id": "2XIyMa_VKv0c", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "404f46a6-e70c-4f80-af8d-d356408a9f18" } }, { "cell_type": "code", "execution_count": null, "source": [ "mix, sr = torchaudio.load('mix.wav')\n", "reverb_clean, sr2 = torchaudio.load('reverb_clean.wav')\n", "clean, sr3 = torchaudio.load('clean.wav')\n", "# All three recordings must share the same sample rate\n", "assert sr == sr2 == sr3\n", "# The mixture is reverberant speech plus noise, so the difference recovers the noise\n", "noise = mix - reverb_clean" ], "outputs": [], "metadata": { "id": "iErB6UhQPtD3" } }, { "cell_type": "markdown", "source": [ "## Note: The MVDR module requires the noisy STFT to have ``torch.cdouble`` dtype. 
We therefore convert the waveforms to ``torch.double`` before computing the STFT." ], "metadata": { "id": "Aq-x_fo5VkwL" } }, { "cell_type": "code", "execution_count": null, "source": [ "mix = mix.to(torch.double)\n", "noise = noise.to(torch.double)\n", "clean = clean.to(torch.double)\n", "reverb_clean = reverb_clean.to(torch.double)" ], "outputs": [], "metadata": { "id": "5c66pHcQV0P9" } }, { "cell_type": "markdown", "source": [ "### Initialize the Spectrogram and InverseSpectrogram modules" ], "metadata": { "id": "05D26we0V4P-" } }, { "cell_type": "code", "execution_count": null, "source": [ "# power=None returns the complex-valued STFT instead of a magnitude or power spectrogram\n", "stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, return_complex=True, power=None)\n", "istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)" ], "outputs": [], "metadata": { "id": "NcGhD7_TUKd1" } }, { "cell_type": "markdown", "source": [ "### Compute the complex-valued STFT of the mixture, clean speech, reverberant clean speech, and noise" ], "metadata": { "id": "-dlJcuSNUCgA" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Each STFT has shape (channel, freq, time)\n", "spec_mix = stft(mix)\n", "spec_clean = stft(clean)\n", "spec_reverb_clean = stft(reverb_clean)\n", "spec_noise = stft(noise)" ], "outputs": [], "metadata": { "id": "w1vO7w1BUKt4" } }, { "cell_type": "markdown", "source": [ "### Generate the Ideal Ratio Mask (IRM)\n", "Note: we found that using the power-ratio mask directly performs better than using its square root. This differs slightly from the conventional definition of the IRM, which takes the square root."
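, "\n",
"\n",
"Written out, with $S(t,f)$ and $N(t,f)$ the STFTs of the target speech and the noise, the masks returned by ``get_irms`` are the power ratios\n",
"\n",
"$$M_S(t,f) = \\frac{|S(t,f)|^2}{|S(t,f)|^2 + |N(t,f)|^2}, \\qquad M_N(t,f) = \\frac{|N(t,f)|^2}{|S(t,f)|^2 + |N(t,f)|^2} = 1 - M_S(t,f),$$\n",
"\n",
"whereas the conventional IRM of speech is $\\sqrt{M_S(t,f)}$."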
], "metadata": { "id": "8SBchrDhURK1" } }, { "cell_type": "code", "execution_count": null, "source": [ "def get_irms(spec_clean, spec_noise, spec_mix):\n", "    # Power spectra (squared magnitudes) of the target speech and the noise.\n", "    # spec_mix is not needed for this mask definition and is left unused.\n", "    power_clean = spec_clean.abs() ** 2\n", "    power_noise = spec_noise.abs() ** 2\n", "    # Power-ratio masks: in each time-frequency bin the speech and noise masks sum to one\n", "    irm_speech = power_clean / (power_clean + power_noise)\n", "    irm_noise = power_noise / (power_clean + power_noise)\n", "\n", "    return irm_speech, irm_noise" ], "outputs": [], "metadata": { "id": "2gB63BoWUmHZ" } }, { "cell_type": "markdown", "source": [ "## Note: We use the reverberant clean speech as the target here; you can also use the dry clean speech instead" ], "metadata": { "id": "reGMDyNCaE7L" } }, { "cell_type": "code", "execution_count": null, "source": [ "irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)" ], "outputs": [], "metadata": { "id": "HSTCGy_5Uqzx" } }, { "cell_type": "markdown", "source": [ "### Apply MVDR beamforming with multi-channel masks" ], "metadata": { "id": "1R5I_TmSUbS0" } }, { "cell_type": "code", "execution_count": null, "source": [ "results_multi = {}\n", "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n", "    # multi_mask=True: the (channel, freq, time) masks are averaged over channels inside MVDR\n", "    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)\n", "    stft_est = mvdr(spec_mix, irm_speech, irm_noise)  # enhanced single-channel STFT\n", "    est = istft(stft_est, length=mix.shape[-1])  # back to a waveform with the input length\n", "    results_multi[solution] = est" ], "outputs": [], "metadata": { "id": "SiWFZgCbadz7" } }, { "cell_type": "markdown", "source": [ "### Apply MVDR beamforming with single-channel masks\n", "(We use the mask of the first channel as an example. 
The best channel to use may depend on the design of the microphone array.)" ], "metadata": { "id": "Ukez6_lcUfna" } }, { "cell_type": "code", "execution_count": null, "source": [ "results_single = {}\n", "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n", "    # multi_mask=False: pass only the (freq, time) masks of channel 0\n", "    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)\n", "    stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])\n", "    est = istft(stft_est, length=mix.shape[-1])\n", "    results_single[solution] = est" ], "outputs": [], "metadata": { "id": "kLeNKsk-VLm5" } }, { "cell_type": "markdown", "source": [ "### Compute SI-SDR scores" ], "metadata": { "id": "uJjJNdYiUnf0" } }, { "cell_type": "code", "execution_count": null, "source": [ "def si_sdr(estimate, reference, epsilon=1e-8):\n", "    # Remove the mean (DC offset) from both signals\n", "    estimate = estimate - estimate.mean()\n", "    reference = reference - reference.mean()\n", "    # Scale the reference by the projection coefficient <estimate, reference> / ||reference||^2\n", "    reference_pow = reference.pow(2).mean(dim=1, keepdim=True)\n", "    mix_pow = (estimate * reference).mean(dim=1, keepdim=True)\n", "    scale = mix_pow / (reference_pow + epsilon)\n", "\n", "    reference = scale * reference\n", "    error = estimate - reference\n", "\n", "    reference_pow = reference.pow(2)\n", "    error_pow = error.pow(2)\n", "\n", "    reference_pow = reference_pow.mean(dim=1)\n", "    error_pow = error_pow.mean(dim=1)\n", "\n", "    # Ratio of scaled-target power to residual power, in dB\n", "    sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)\n", "    return sisdr.item()" ], "outputs": [], "metadata": { "id": "MgmAJcyiU-FU" } }, { "cell_type": "markdown", "source": [ "### Single-channel mask results" ], "metadata": { "id": "3TCJEwTOUxci" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Score each estimate against the reverberant clean speech of the reference channel (channel 0)\n", "for solution in results_single:\n", "    print(solution + \": \", si_sdr(results_single[solution][None, ...], reverb_clean[0:1]))" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NrUXXj98VVY7", "outputId": "bc113347-70e3-47a9-8479-8aeeeca80abf" } }, { "cell_type": "markdown", "source": [ "### Multi-channel mask 
results" ], "metadata": { "id": "-7AnjM-gU3c8" } }, { "cell_type": "code", "execution_count": null, "source": [ "for solution in results_multi:\n", "    print(solution + \": \", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1]))" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S_VINTnlXobM", "outputId": "234b5615-63e7-44d8-f816-a6cc05999e52" } }, { "cell_type": "markdown", "source": [ "### Display the mixture audio" ], "metadata": { "id": "_vOK8vgmU_UP" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"Mixture\")\n", "ipd.Audio(mix[0], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "QaKauQIHYctE", "outputId": "674c7f9b-62a3-4298-81ac-d3ab1ee43cd7" } }, { "cell_type": "markdown", "source": [ "### Display the noise" ], "metadata": { "id": "R-QGGm87VFQI" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"Noise\")\n", "ipd.Audio(noise[0], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "l1WgzxIZYhlk", "outputId": "7b100679-b4a0-47ff-b30b-9f4cb9dca3d1" } }, { "cell_type": "markdown", "source": [ "### Display the dry clean speech" ], "metadata": { "id": "P3kB-jzpVKKu" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"Dry clean speech\")\n", "ipd.Audio(clean[0], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "pwAWvlRAVJkT", "outputId": "5e173a1b-2ba8-4797-8f3a-e41cbf05ac2b" } }, { "cell_type": "markdown", "source": [ "### Display the enhanced audio" ], "metadata": { "id": "RIlyzL1wVTnr" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"multi-channel mask, ref_channel solution\")\n", "ipd.Audio(results_multi['ref_channel'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", 
"height": 92 }, "id": "M3YQsledVIQ5", "outputId": "43d9ee34-6933-401b-baf9-e4cdb7d79b63" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"multi-channel mask, stv_evd solution\")\n", "ipd.Audio(results_multi['stv_evd'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "UhYOHLvCVWBN", "outputId": "761468ec-ebf9-4b31-ad71-bfa2e15fed37" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"multi-channel mask, stv_power solution\")\n", "ipd.Audio(results_multi['stv_power'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "9dv8VDtCVXzd", "outputId": "1ae61ea3-d3c4-479f-faad-7439f942aac1" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"single-channel mask, ref_channel solution\")\n", "ipd.Audio(results_single['ref_channel'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "jCFUN890VZdh", "outputId": "c0d2a928-5dd0-4584-b277-7838ac4a9e6b" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"single-channel mask, stv_evd solution\")\n", "ipd.Audio(results_single['stv_evd'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "hzlzagsKVbAv", "outputId": "96af9e37-82ca-4544-9c08-421fe222bde4" } }, { "cell_type": "code", "execution_count": null, "source": [ "print(\"single-channel mask, stv_power solution\")\n", "ipd.Audio(results_single['stv_power'], rate=16000)" ], "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 92 }, "id": "A4igQpTnVctG", "outputId": "cf968089-9274-4c1c-a1a5-32b220de0bf9" } } ] }