Commit 68fa1d3f authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add SquimSubjective pre-trained pipeline (#3197)

Summary:
The PR adds the pre-trained pipeline for the `SquimSubjective` model, which predicts MOS scores for the speech enhancement task.

Pull Request resolved: https://github.com/pytorch/audio/pull/3197

Reviewed By: mthrok

Differential Revision: D44313244

Pulled By: nateanl

fbshipit-source-id: 905095ff77006e9f441faa826fc25d9d8681e8aa
parent 92eff154
......@@ -97,6 +97,6 @@ Pre-trained Model License
The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
For instance, SquimObjective model is released under the Creative Commons Attribution 4.0 International license. See [DNS 2020 license](https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE) for additional details.
For instance, SquimSubjective model is released under the Creative Commons Attribution Non Commercial 4.0 International (CC-BY-NC 4.0) license. See [the link](https://zenodo.org/record/4660670#.ZBtWPOxuerN) for additional details.
Other pre-trained models that have different licenses are noted in the documentation. Please check out the [documentation page](https://pytorch.org/audio/main/).
......@@ -70,3 +70,28 @@ Pretrained Models
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines the speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
......@@ -540,3 +540,19 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal={arXiv preprint arXiv:2206.12285},
year={2022}
}
@article{cooper2021voices,
title={How do voices from past speech synthesis challenges compare today?},
author={Cooper, Erica and Yamagishi, Junichi},
journal={arXiv preprint arXiv:2105.02373},
year={2021}
}
@article{mysore2014can,
title={Can we automatically transform speech recorded on common consumer devices in real-world environments into professional production quality speech?—a dataset, insights, and challenges},
author={Mysore, Gautham J},
journal={IEEE Signal Processing Letters},
volume={22},
number={8},
pages={1006--1010},
year={2014},
publisher={IEEE}
}
import pytest
import torchaudio
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
@pytest.mark.parametrize(
......@@ -20,3 +20,24 @@ def test_squim_objective_pretrained_weights(lang, expected, sample_speech):
scores = model(waveform)
for i in range(3):
assert abs(scores[i].item() - expected[i]) < 1e-5
@pytest.mark.parametrize(
    "task,expected",
    [
        ("speech_separation", [3.9257140159606934, 3.9391300678253174]),
    ],
)
def test_squim_subjective_pretrained_weights(task, expected, mixture_source, clean_sources):
    """Test that the metric scores estimated by the SquimSubjective bundle match the expected results.

    Args:
        task (str): Name of the task the fixture audio comes from (parametrized).
        expected (list of float): Expected MOS score for each clean reference in ``clean_sources``.
        mixture_source: Fixture providing the path to the degraded/mixture audio file.
        clean_sources: Fixture providing paths to the clean reference audio files.
    """
    bundle = SQUIM_SUBJECTIVE

    # Get SquimSubjective model with pretrained weights
    model = bundle.get_model()

    # Load input mixture audio
    waveform, sample_rate = torchaudio.load(mixture_source)
    for i, source in enumerate(clean_sources):
        # Load clean reference; the model scores the mixture against it
        clean_waveform, sample_rate = torchaudio.load(source)
        score = model(waveform, clean_waveform)
        # Fixed weights give deterministic scores, so a tight tolerance is safe
        assert abs(score.item() - expected[i]) < 1e-5
......@@ -130,7 +130,7 @@ def squim_subjective_model(
Args:
ssl_type (str): Type of self-supervised learning (SSL) models.
Must be one of ["wav2vec2_base", "wav2vec2_large"].
feature_dim (int): Feature dimension of the SSL feature representation.
feat_dim (int): Feature dimension of the SSL feature representation.
proj_dim (int): Output dimension of projection layer.
att_dim (int): Dimension of attention scores.
"""
......
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .squim_pipeline import SQUIM_OBJECTIVE, SquimObjectiveBundle
from .squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
......@@ -8,5 +8,7 @@ __all__ = [
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
]
......@@ -2,7 +2,7 @@ from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url
from torchaudio.prototype.models import squim_objective_base, SquimObjective
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
@dataclass
......@@ -88,3 +88,90 @@ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline, trained on the *DNS 2020 D
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
@dataclass
class SquimSubjectiveBundle:
    """Data class that bundles associated information to use pretrained
    :py:class:`~torchaudio.prototype.models.SquimSubjective` model.

    This class provides interfaces for instantiating the pretrained model along with
    the information necessary to retrieve pretrained weights and additional data
    to be used with the model.

    Torchaudio library instantiates objects of this class, each of which represents
    a different pretrained model. Client code should access pretrained models via these
    instances.

    This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
    A typical use case would be a flow like `waveform -> score`. Please see below for the code example.

    Example: Estimate the subjective metric scores for the input waveform.
        >>> import torch
        >>> import torchaudio
        >>> # Since SquimSubjective bundle is in prototypes, it needs to be exported explicitly
        >>> from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE as bundle
        >>>
        >>> # Load the SquimSubjective bundle
        >>> model = bundle.get_model()
        Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
        100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
        >>>
        >>> # Resample audio to the expected sampling rate
        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
        >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
        >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
        >>>
        >>> # Estimate subjective metric scores
        >>> score = model(waveform, reference)
        >>> print(f"MOS: {score}.")
    """  # noqa: E501

    # Checkpoint filename, resolved relative to download.pytorch.org/torchaudio/models/
    _path: str
    # Sampling rate (in Hz) of the audio the model was trained on
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        # Download the pretrained weights (or load them from the torch.hub cache).
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
        """Construct the SquimSubjective model, and load the pretrained weight.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`

        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.prototype.models.SquimSubjective`.
        """
        model = squim_subjective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        # Pretrained pipelines are inference-only; disable training-mode layers.
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.

        :type: float
        """
        return self._sample_rate
# Pretrained SquimSubjective pipeline instance; see the __doc__ assignment below
# for dataset and license details.
SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
    _path="squim_subjective_bvcc_daps.pth",
    _sample_rate=16000,
)

SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline, trained on the *BVCC*
:cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.

The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_subjective_base`.
The weights are under `Creative Commons Attribution Non Commercial 4.0 International
<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.

Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment