Commit b0155938 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Rename SQUIM_OBJECTIVE model to SquimObjective (#3087)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3087

Reviewed By: xiaohui-zhang, mthrok

Differential Revision: D43509865

Pulled By: nateanl

fbshipit-source-id: 569cc2ee8edd9de0b7d255a1e1075ac812b26cc8
parent b35a5fcf
...@@ -10,7 +10,7 @@ from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model ...@@ -10,7 +10,7 @@ from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
from .squim import SQUIM_OBJECTIVE, squim_objective_base, squim_objective_model from .squim import squim_objective_base, squim_objective_model, SquimObjective
__all__ = [ __all__ = [
"conformer_rnnt_base", "conformer_rnnt_base",
...@@ -31,5 +31,5 @@ __all__ = [ ...@@ -31,5 +31,5 @@ __all__ = [
"hifigan_vocoder", "hifigan_vocoder",
"squim_objective_base", "squim_objective_base",
"squim_objective_model", "squim_objective_model",
"SQUIM_OBJECTIVE", "SquimObjective",
] ]
from .objective import SQUIM_OBJECTIVE, squim_objective_base, squim_objective_model from .objective import squim_objective_base, squim_objective_model, SquimObjective
__all__ = [ __all__ = [
"squim_objective_base", "squim_objective_base",
"squim_objective_model", "squim_objective_model",
"SQUIM_OBJECTIVE", "SquimObjective",
] ]
...@@ -202,8 +202,9 @@ class AutoPool(nn.Module): ...@@ -202,8 +202,9 @@ class AutoPool(nn.Module):
return out return out
class SQUIM_OBJECTIVE(nn.Module): class SquimObjective(nn.Module):
"""SQUIM_OBJECTIVE model that predicts objective metric scorres for speech enhancement (e.g., STOI, PESQ, and SI-SDR). """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
The model uses *dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual` to model sequential signals, The model uses *dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual` to model sequential signals,
and multiple transformer branches to estimate the objective metric scores, respectively. and multiple transformer branches to estimate the objective metric scores, respectively.
...@@ -219,7 +220,7 @@ class SQUIM_OBJECTIVE(nn.Module): ...@@ -219,7 +220,7 @@ class SQUIM_OBJECTIVE(nn.Module):
dprnn: nn.Module, dprnn: nn.Module,
branches: nn.ModuleList, branches: nn.ModuleList,
): ):
super(SQUIM_OBJECTIVE, self).__init__() super(SquimObjective, self).__init__()
self.encoder = encoder self.encoder = encoder
self.dprnn = dprnn self.dprnn = dprnn
self.branches = branches self.branches = branches
...@@ -285,8 +286,8 @@ def squim_objective_model( ...@@ -285,8 +286,8 @@ def squim_objective_model(
rnn_type: str, rnn_type: str,
chunk_size: int, chunk_size: int,
chunk_stride: Optional[int] = None, chunk_stride: Optional[int] = None,
) -> SQUIM_OBJECTIVE: ) -> SquimObjective:
"""Build a custome :class:`torchaudio.prototype.models.SQUIM_OBJECTIVE` model. """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
Args: Args:
feat_dim (int, optional): The feature dimension after Encoder module. feat_dim (int, optional): The feature dimension after Encoder module.
...@@ -310,11 +311,11 @@ def squim_objective_model( ...@@ -310,11 +311,11 @@ def squim_objective_model(
_create_branch(d_model, nhead, "sisdr"), _create_branch(d_model, nhead, "sisdr"),
] ]
) )
return SQUIM_OBJECTIVE(encoder, dprnn, branches) return SquimObjective(encoder, dprnn, branches)
def squim_objective_base() -> SQUIM_OBJECTIVE: def squim_objective_base() -> SquimObjective:
"""Build :class:`torchaudio.prototype.models.SQUIM_OBJECTIVE` model with default arguments.""" """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
return squim_objective_model( return squim_objective_model(
feat_dim=256, feat_dim=256,
win_len=64, win_len=64,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment