Commit b0155938 authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Rename SQUIM_OBJECTIVE model to SquimObjective (#3087)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3087

Reviewed By: xiaohui-zhang, mthrok

Differential Revision: D43509865

Pulled By: nateanl

fbshipit-source-id: 569cc2ee8edd9de0b7d255a1e1075ac812b26cc8
parent b35a5fcf
......@@ -10,7 +10,7 @@ from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_model
from .squim import SQUIM_OBJECTIVE, squim_objective_base, squim_objective_model
from .squim import squim_objective_base, squim_objective_model, SquimObjective
__all__ = [
"conformer_rnnt_base",
......@@ -31,5 +31,5 @@ __all__ = [
"hifigan_vocoder",
"squim_objective_base",
"squim_objective_model",
"SQUIM_OBJECTIVE",
"SquimObjective",
]
from .objective import SQUIM_OBJECTIVE, squim_objective_base, squim_objective_model
from .objective import squim_objective_base, squim_objective_model, SquimObjective
__all__ = [
"squim_objective_base",
"squim_objective_model",
"SQUIM_OBJECTIVE",
"SquimObjective",
]
......@@ -202,8 +202,9 @@ class AutoPool(nn.Module):
return out
class SQUIM_OBJECTIVE(nn.Module):
"""SQUIM_OBJECTIVE model that predicts objective metric scorres for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
class SquimObjective(nn.Module):
"""Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
The model uses *dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual` to model sequential signals,
and multiple transformer branches to estimate the objective metric scores, respectively.
......@@ -219,7 +220,7 @@ class SQUIM_OBJECTIVE(nn.Module):
dprnn: nn.Module,
branches: nn.ModuleList,
):
super(SQUIM_OBJECTIVE, self).__init__()
super(SquimObjective, self).__init__()
self.encoder = encoder
self.dprnn = dprnn
self.branches = branches
......@@ -285,8 +286,8 @@ def squim_objective_model(
rnn_type: str,
chunk_size: int,
chunk_stride: Optional[int] = None,
) -> SQUIM_OBJECTIVE:
"""Build a custome :class:`torchaudio.prototype.models.SQUIM_OBJECTIVE` model.
) -> SquimObjective:
"""Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.
Args:
feat_dim (int, optional): The feature dimension after Encoder module.
......@@ -310,11 +311,11 @@ def squim_objective_model(
_create_branch(d_model, nhead, "sisdr"),
]
)
return SQUIM_OBJECTIVE(encoder, dprnn, branches)
return SquimObjective(encoder, dprnn, branches)
def squim_objective_base() -> SQUIM_OBJECTIVE:
"""Build :class:`torchaudio.prototype.models.SQUIM_OBJECTIVE` model with default arguments."""
def squim_objective_base() -> SquimObjective:
"""Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
return squim_objective_model(
feat_dim=256,
win_len=64,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment