Commit b7d2d928 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Move TorchAudio-Squim models to Beta (#3512)

Summary:
The PR moves the `SquimObjective` and `SquimSubjective` models and their corresponding factory functions and pre-trained pipelines out of prototype and into the core directory. They will be included in the next official release.

Pull Request resolved: https://github.com/pytorch/audio/pull/3512

Reviewed By: mthrok

Differential Revision: D47837434

Pulled By: nateanl

fbshipit-source-id: d0639f29079f7e1afc30f236849e530c8cadffd8
parent d6aeaa74
......@@ -46,6 +46,14 @@
"hdemucs_medium",
"hdemucs_high",
],
"torchaudio.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
}
-%}
{%- set prototype_factory = {
......
......@@ -13,14 +13,6 @@
}
-%}
{%- set factory={
"torchaudio.prototype.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.prototype.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
"torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [
"conformer_wav2vec2_pretrain_model",
"conformer_wav2vec2_pretrain_base",
......
......@@ -28,6 +28,8 @@ For such models, factory functions are provided.
HuBERTPretrainModel
RNNT
RNNTBeamSearch
SquimObjective
SquimSubjective
Tacotron2
Wav2Letter
Wav2Vec2Model
......
......@@ -217,3 +217,53 @@ Pretrained Models
CONVTASNET_BASE_LIBRI2MIX
HDEMUCS_HIGH_MUSDB_PLUS
HDEMUCS_HIGH_MUSDB
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
......@@ -23,8 +23,6 @@ For such models, factory functions are provided.
ConformerWav2Vec2PretrainModel
ConvEmformer
HiFiGANVocoder
SquimObjective
SquimSubjective
Prototype Factory Functions of Beta Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -45,53 +45,3 @@ Pretrained Models
:template: autosummary/bundle_data.rst
HIFIGAN_VOCODER_V3_LJSPEECH
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
......@@ -80,7 +80,7 @@ print(torchaudio.__version__)
try:
from pesq import pesq
from pystoi import stoi
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
import google.colab # noqa: F401
......
import pytest
import torchaudio
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
@pytest.mark.parametrize(
......
import torch
from parameterized import parameterized
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base
from torchaudio.models import squim_objective_base, squim_subjective_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
......
......@@ -5,6 +5,14 @@ from .deepspeech import DeepSpeech
from .emformer import Emformer
from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)
from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
......@@ -68,4 +76,10 @@ __all__ = [
"hdemucs_low",
"hdemucs_medium",
"hdemucs_high",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
]
......@@ -4,6 +4,7 @@ from ._source_separation_pipeline import (
HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle,
)
from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
from ._tts import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
......@@ -90,4 +91,8 @@ __all__ = [
"CONVTASNET_BASE_LIBRI2MIX",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
]
......@@ -2,13 +2,13 @@ from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
@dataclass
class SquimObjectiveBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimObjective` model.
:py:class:`~torchaudio.models.SquimObjective` model.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
......@@ -24,8 +24,7 @@ class SquimObjectiveBundle:
Example: Estimate the objective metric scores for the input waveform.
>>> import torch
>>> import torchaudio
>>> # Since SquimObjective bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE as bundle
>>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
>>>
>>> # Load the SquimObjective bundle
>>> model = bundle.get_model()
......@@ -59,7 +58,7 @@ class SquimObjectiveBundle:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`.
Variation of :py:class:`~torchaudio.models.SquimObjective`.
"""
model = squim_objective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs))
......@@ -82,7 +81,7 @@ SQUIM_OBJECTIVE = SquimObjectiveBundle(
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_objective_base`.
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
......@@ -93,7 +92,7 @@ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach desc
@dataclass
class SquimSubjectiveBundle:
"""Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimSubjective` model.
:py:class:`~torchaudio.models.SquimSubjective` model.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
......@@ -109,8 +108,7 @@ class SquimSubjectiveBundle:
Example: Estimate the subjective metric scores for the input waveform.
>>> import torch
>>> import torchaudio
>>> # Since SquimSubjective bundle is in prototypes, it needs to be exported explicitly
>>> from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE as bundle
>>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
>>>
>>> # Load the SquimSubjective bundle
>>> model = bundle.get_model()
......@@ -146,7 +144,7 @@ class SquimSubjectiveBundle:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`.
Variation of :py:class:`~torchaudio.models.SquimSubjective`.
"""
model = squim_subjective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs))
......@@ -170,7 +168,7 @@ SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_subjective_base`.
The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
The weights are under `Creative Commons Attribution Non Commercial 4.0 International
<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.
......
......@@ -11,14 +11,6 @@ from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)
__all__ = [
"conformer_rnnt_base",
......@@ -42,10 +34,4 @@ __all__ = [
"hifigan_vocoder_v2",
"hifigan_vocoder_v3",
"hifigan_vocoder",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
]
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment