Commit b7d2d928 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Move TorchAudio-Squim models to Beta (#3512)

Summary:
The PR moves the `SquimObjective` and `SquimSubjective` models and their corresponding factory functions and pre-trained pipelines out of prototype and into the core directory. They will be included in the next official release.

Pull Request resolved: https://github.com/pytorch/audio/pull/3512

Reviewed By: mthrok

Differential Revision: D47837434

Pulled By: nateanl

fbshipit-source-id: d0639f29079f7e1afc30f236849e530c8cadffd8
parent d6aeaa74
...@@ -46,6 +46,14 @@ ...@@ -46,6 +46,14 @@
"hdemucs_medium", "hdemucs_medium",
"hdemucs_high", "hdemucs_high",
], ],
"torchaudio.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
} }
-%} -%}
{%- set prototype_factory = { {%- set prototype_factory = {
......
...@@ -13,14 +13,6 @@ ...@@ -13,14 +13,6 @@
} }
-%} -%}
{%- set factory={ {%- set factory={
"torchaudio.prototype.models.SquimObjective": [
"squim_objective_model",
"squim_objective_base",
],
"torchaudio.prototype.models.SquimSubjective": [
"squim_subjective_model",
"squim_subjective_base",
],
"torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [ "torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [
"conformer_wav2vec2_pretrain_model", "conformer_wav2vec2_pretrain_model",
"conformer_wav2vec2_pretrain_base", "conformer_wav2vec2_pretrain_base",
......
...@@ -28,6 +28,8 @@ For such models, factory functions are provided. ...@@ -28,6 +28,8 @@ For such models, factory functions are provided.
HuBERTPretrainModel HuBERTPretrainModel
RNNT RNNT
RNNTBeamSearch RNNTBeamSearch
SquimObjective
SquimSubjective
Tacotron2 Tacotron2
Wav2Letter Wav2Letter
Wav2Vec2Model Wav2Vec2Model
......
...@@ -217,3 +217,53 @@ Pretrained Models ...@@ -217,3 +217,53 @@ Pretrained Models
CONVTASNET_BASE_LIBRI2MIX CONVTASNET_BASE_LIBRI2MIX
HDEMUCS_HIGH_MUSDB_PLUS HDEMUCS_HIGH_MUSDB_PLUS
HDEMUCS_HIGH_MUSDB HDEMUCS_HIGH_MUSDB
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
...@@ -23,8 +23,6 @@ For such models, factory functions are provided. ...@@ -23,8 +23,6 @@ For such models, factory functions are provided.
ConformerWav2Vec2PretrainModel ConformerWav2Vec2PretrainModel
ConvEmformer ConvEmformer
HiFiGANVocoder HiFiGANVocoder
SquimObjective
SquimSubjective
Prototype Factory Functions of Beta Models Prototype Factory Functions of Beta Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -45,53 +45,3 @@ Pretrained Models ...@@ -45,53 +45,3 @@ Pretrained Models
:template: autosummary/bundle_data.rst :template: autosummary/bundle_data.rst
HIFIGAN_VOCODER_V3_LJSPEECH HIFIGAN_VOCODER_V3_LJSPEECH
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
...@@ -80,7 +80,7 @@ print(torchaudio.__version__) ...@@ -80,7 +80,7 @@ print(torchaudio.__version__)
try: try:
from pesq import pesq from pesq import pesq
from pystoi import stoi from pystoi import stoi
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
import google.colab # noqa: F401 import google.colab # noqa: F401
......
import pytest import pytest
import torchaudio import torchaudio
from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base from torchaudio.models import squim_objective_base, squim_subjective_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
......
...@@ -5,6 +5,14 @@ from .deepspeech import DeepSpeech ...@@ -5,6 +5,14 @@ from .deepspeech import DeepSpeech
from .emformer import Emformer from .emformer import Emformer
from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
from .rnnt_decoder import Hypothesis, RNNTBeamSearch from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)
from .tacotron2 import Tacotron2 from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter from .wav2letter import Wav2Letter
from .wav2vec2 import ( from .wav2vec2 import (
...@@ -68,4 +76,10 @@ __all__ = [ ...@@ -68,4 +76,10 @@ __all__ = [
"hdemucs_low", "hdemucs_low",
"hdemucs_medium", "hdemucs_medium",
"hdemucs_high", "hdemucs_high",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
] ]
...@@ -4,6 +4,7 @@ from ._source_separation_pipeline import ( ...@@ -4,6 +4,7 @@ from ._source_separation_pipeline import (
HDEMUCS_HIGH_MUSDB_PLUS, HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle, SourceSeparationBundle,
) )
from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
from ._tts import ( from ._tts import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
...@@ -90,4 +91,8 @@ __all__ = [ ...@@ -90,4 +91,8 @@ __all__ = [
"CONVTASNET_BASE_LIBRI2MIX", "CONVTASNET_BASE_LIBRI2MIX",
"HDEMUCS_HIGH_MUSDB_PLUS", "HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB", "HDEMUCS_HIGH_MUSDB",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
] ]
...@@ -2,13 +2,13 @@ from dataclasses import dataclass ...@@ -2,13 +2,13 @@ from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url from torchaudio._internal import load_state_dict_from_url
from torchaudio.prototype.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
@dataclass @dataclass
class SquimObjectiveBundle: class SquimObjectiveBundle:
"""Data class that bundles associated information to use pretrained """Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimObjective` model. :py:class:`~torchaudio.models.SquimObjective` model.
This class provides interfaces for instantiating the pretrained model along with This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data the information necessary to retrieve pretrained weights and additional data
...@@ -24,8 +24,7 @@ class SquimObjectiveBundle: ...@@ -24,8 +24,7 @@ class SquimObjectiveBundle:
Example: Estimate the objective metric scores for the input waveform. Example: Estimate the objective metric scores for the input waveform.
>>> import torch >>> import torch
>>> import torchaudio >>> import torchaudio
>>> # Since SquimObjective bundle is in prototypes, it needs to be exported explicitly >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
>>> from torchaudio.prototype.pipelines import SQUIM_OBJECTIVE as bundle
>>> >>>
>>> # Load the SquimObjective bundle >>> # Load the SquimObjective bundle
>>> model = bundle.get_model() >>> model = bundle.get_model()
...@@ -59,7 +58,7 @@ class SquimObjectiveBundle: ...@@ -59,7 +58,7 @@ class SquimObjectiveBundle:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns: Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`. Variation of :py:class:`~torchaudio.models.SquimObjective`.
""" """
model = squim_objective_base() model = squim_objective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs)) model.load_state_dict(self._get_state_dict(dl_kwargs))
...@@ -82,7 +81,7 @@ SQUIM_OBJECTIVE = SquimObjectiveBundle( ...@@ -82,7 +81,7 @@ SQUIM_OBJECTIVE = SquimObjectiveBundle(
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`. :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_objective_base`. The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__. <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
...@@ -93,7 +92,7 @@ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach desc ...@@ -93,7 +92,7 @@ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach desc
@dataclass @dataclass
class SquimSubjectiveBundle: class SquimSubjectiveBundle:
"""Data class that bundles associated information to use pretrained """Data class that bundles associated information to use pretrained
:py:class:`~torchaudio.prototype.models.SquimSubjective` model. :py:class:`~torchaudio.models.SquimSubjective` model.
This class provides interfaces for instantiating the pretrained model along with This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data the information necessary to retrieve pretrained weights and additional data
...@@ -109,8 +108,7 @@ class SquimSubjectiveBundle: ...@@ -109,8 +108,7 @@ class SquimSubjectiveBundle:
Example: Estimate the subjective metric scores for the input waveform. Example: Estimate the subjective metric scores for the input waveform.
>>> import torch >>> import torch
>>> import torchaudio >>> import torchaudio
>>> # Since SquimSubjective bundle is in prototypes, it needs to be exported explicitly >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
>>> from torchaudio.prototype.pipelines import SQUIM_SUBJECTIVE as bundle
>>> >>>
>>> # Load the SquimSubjective bundle >>> # Load the SquimSubjective bundle
>>> model = bundle.get_model() >>> model = bundle.get_model()
...@@ -146,7 +144,7 @@ class SquimSubjectiveBundle: ...@@ -146,7 +144,7 @@ class SquimSubjectiveBundle:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns: Returns:
Variation of :py:class:`~torchaudio.prototype.models.SquimObjective`. Variation of :py:class:`~torchaudio.models.SquimSubjective`.
""" """
model = squim_subjective_base() model = squim_subjective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs)) model.load_state_dict(self._get_state_dict(dl_kwargs))
...@@ -170,7 +168,7 @@ SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained ...@@ -170,7 +168,7 @@ SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio` as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets. on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.squim_subjective_base`. The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
The weights are under `Creative Commons Attribution Non Commercial 4.0 International The weights are under `Creative Commons Attribution Non Commercial 4.0 International
<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__. <https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.
......
...@@ -11,14 +11,6 @@ from .conv_emformer import ConvEmformer ...@@ -11,14 +11,6 @@ from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
from .squim import (
squim_objective_base,
squim_objective_model,
squim_subjective_base,
squim_subjective_model,
SquimObjective,
SquimSubjective,
)
__all__ = [ __all__ = [
"conformer_rnnt_base", "conformer_rnnt_base",
...@@ -42,10 +34,4 @@ __all__ = [ ...@@ -42,10 +34,4 @@ __all__ = [
"hifigan_vocoder_v2", "hifigan_vocoder_v2",
"hifigan_vocoder_v3", "hifigan_vocoder_v3",
"hifigan_vocoder", "hifigan_vocoder",
"squim_objective_base",
"squim_objective_model",
"squim_subjective_base",
"squim_subjective_model",
"SquimObjective",
"SquimSubjective",
] ]
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3 from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
__all__ = [ __all__ = [
"EMFORMER_RNNT_BASE_MUSTC", "EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3", "EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH", "HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle", "HiFiGANVocoderBundle",
"SQUIM_OBJECTIVE",
"SQUIM_SUBJECTIVE",
"SquimObjectiveBundle",
"SquimSubjectiveBundle",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment