Commit 46fae2fe authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot

Add SquimObjectiveBundle to prototype (#3103)

Summary:
Add pre-trained pipeline support for the `SquimObjective` model. The pre-trained model is trained on the DNS 2020 challenge dataset.

Pull Request resolved: https://github.com/pytorch/audio/pull/3103

Reviewed By: xiaohui-zhang, mthrok

Differential Revision: D43611794

Pulled By: nateanl

fbshipit-source-id: 0ac76a27e7027a43ffccb158385ddb2409b8526d
parent bc61f109
...@@ -91,3 +91,12 @@ Disclaimer on Datasets
This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
Pre-trained Model License
-------------------------
The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
For instance, the SquimObjective model is released under the Creative Commons Attribution 4.0 International license. See the [DNS 2020 license](https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE) for additional details.
Other pre-trained models that have different licenses are noted in the documentation. Please check out the [documentation page](https://pytorch.org/audio/main/).
...@@ -50,3 +50,28 @@ Pretrained Models
:template: autosummary/bundle_data.rst
HIFIGAN_VOCODER_V3_LJSPEECH
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
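For a quick, illustrative sketch of the interface described above ("speech.wav" is a placeholder path, and the metric ranges in the final comment are the conventional ones for these metrics, not guarantees of this model):

>>> import torchaudio
>>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE
>>> model = SQUIM_OBJECTIVE.get_model()
>>> waveform, sr = torchaudio.load("speech.wav")  # placeholder path to any speech recording
>>> waveform = torchaudio.functional.resample(waveform, sr, SQUIM_OBJECTIVE.sample_rate)
>>> scores = model(waveform)
>>> # scores[0], scores[1], scores[2] are the STOI, PESQ, and SI-SDR estimates;
>>> # STOI typically falls in [0, 1], PESQ (MOS-LQO) roughly in [-0.5, 4.5], and SI-SDR is in dB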
...@@ -527,12 +527,10 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
title = "Absorption (acoustics) --- {W}ikipedia{,} The Free Encyclopedia",
url = "https://en.wikipedia.org/wiki/Absorption_(acoustics)",
note = "[Online]"
}
@inproceedings{luo2020dual,
    title={Dual-path rnn: efficient long sequence modeling for time-domain single-channel speech separation},
    author={Luo, Yi and Chen, Zhuo and Yoshioka, Takuya},
    booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    pages={46--50},
    year={2020},
    organization={IEEE}
}
@article{reddy2020interspeech,
    title={The interspeech 2020 deep noise suppression challenge: Datasets, subjective testing framework, and challenge results},
    author={Reddy, Chandan KA and Gopal, Vishak and Cutler, Ross and Beyrami, Ebrahim and Cheng, Roger and Dubey, Harishchandra and Matusevych, Sergiy and Aichner, Robert and Aazami, Ashkan and Braun, Sebastian and others},
    journal={arXiv preprint arXiv:2005.13981},
    year={2020}
}
import pytest
import torchaudio
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE
@pytest.mark.parametrize(
"lang,expected",
[
("en", [0.9978380799293518, 4.23893404006958, 24.217193603515625]),
],
)
def test_squim_objective_pretrained_weights(lang, expected, sample_speech):
"""Test that the metric scores estimated by SquimObjective Bundle is identical to the expected result."""
bundle = SQUIM_OBJECTIVE
# Get SquimObjective model
model = bundle.get_model()
# Load the sample speech waveform
waveform, sample_rate = torchaudio.load(sample_speech)
scores = model(waveform)
for i in range(3):
assert abs(scores[i].item() - expected[i]) < 1e-5
...@@ -25,6 +25,23 @@ class TestSQUIM(TorchaudioTestCase):
model = squim_objective_base()
self._smoke_test_objective(model, torch.device("cuda"), dtype)
def test_batch_consistency(self):
model = squim_objective_base()
model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames)
ref_scores = model(waveforms)
hyp_scores = [torch.zeros(batch_size), torch.zeros(batch_size), torch.zeros(batch_size)]
for i in range(batch_size):
scores = model(waveforms[i : i + 1])
for j in range(3):
hyp_scores[j][i] = scores[j]
self.assertEqual(len(hyp_scores), len(ref_scores))
for i in range(len(ref_scores)):
self.assertEqual(hyp_scores[i], ref_scores[i])
def test_torchscript_consistency(self):
model = squim_objective_base()
model.eval()
......
...@@ -205,8 +205,6 @@ class AutoPool(nn.Module):
class SquimObjective(nn.Module):
"""Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
The model uses *dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual` to model sequential signals,
and multiple transformer branches to estimate the objective metric scores, respectively.
Args:
encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
......
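As a minimal sketch of exercising the model directly through the `squim_objective_base` factory (random input is used only to show the call signature; scores from a randomly initialized model are meaningless):

>>> import torch
>>> from torchaudio.prototype.models import squim_objective_base
>>> model = squim_objective_base()  # randomly initialized, no pretrained weights
>>> model.eval()
>>> waveform = torch.randn(2, 16000)  # (batch, time), e.g. one second of 16 kHz audio per item
>>> with torch.no_grad():
...     scores = model(waveform)
>>> # scores[0], scores[1], scores[2] hold the STOI, PESQ, and SI-SDR estimates for each item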
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .squim_pipeline import SQUIM_OBJECTIVE, SquimObjectiveBundle
__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
"SQUIM_OBJECTIVE",
"SquimObjectiveBundle",
]
from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url
from torchaudio.prototype.models import squim_objective_base, SquimObjective
@dataclass
class SquimObjectiveBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimObjective` model.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
to be used with the model.
Torchaudio library instantiates objects of this class, each of which represents
a different pretrained model. Client code should access pretrained models via these
instances.
This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, and SI-SDR.
A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
Example: Estimate the objective metric scores for the input waveform.
>>> import torch
>>> import torchaudio
>>> # Since the SquimObjective bundle is in prototype, it needs to be imported explicitly
>>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE as bundle
>>>
>>> # Load the SquimObjective bundle
>>> model = bundle.get_model()
Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
>>>
>>> # Load a speech waveform ("speech.wav" is a placeholder path) and resample it to the expected sampling rate
>>> waveform, sample_rate = torchaudio.load("speech.wav")
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Estimate objective metric scores
>>> scores = model(waveform)
>>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
""" # noqa: E501
_path: str
_sample_rate: float
def _get_state_dict(self, dl_kwargs):
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
"""Construct the SquimObjective model, and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`.
"""
model = squim_objective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs))
model.eval()
return model
@property
def sample_rate(self):
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return self._sample_rate
SQUIM_OBJECTIVE = SquimObjectiveBundle(
"squim_objective_dns2020.pth",
_sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline, trained on the *DNS 2020 Dataset*
:cite:`reddy2020interspeech`.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""