Commit a8a16238 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add SquimSubjective Model (#3189)

Summary:
Add model architecture and factory functions for `SquimSubjective`, which predicts subjective evaluation metric scores (e.g., MOS) for the speech enhancement task.
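
A minimal usage sketch of the new factory function (illustrative only; random tensors stand in for real 16 kHz speech and a non-matching clean reference):

    import torch
    from torchaudio.prototype.models import squim_subjective_base

    model = squim_subjective_base().eval()
    waveform = torch.randn(1, 16000)   # speech under evaluation
    reference = torch.randn(1, 16000)  # non-matching clean reference
    with torch.no_grad():
        mos = model(waveform, reference)  # estimated MOS, shape (batch,)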

Pull Request resolved: https://github.com/pytorch/audio/pull/3189

Reviewed By: mthrok

Differential Revision: D44267255

Pulled By: nateanl

fbshipit-source-id: f8060398b14c625b38ea1bb2417f61aeaec3f1db
parent f8d8ffb5
@@ -17,6 +17,10 @@
        "squim_objective_model",
        "squim_objective_base",
    ],
    "torchaudio.prototype.models.SquimSubjective": [
        "squim_subjective_model",
        "squim_subjective_base",
    ],
    "torchaudio.prototype.models.ConformerWav2Vec2PretrainModel": [
        "conformer_wav2vec2_pretrain_model",
        "conformer_wav2vec2_pretrain_base",
...
@@ -24,6 +24,7 @@ For such models, factory functions are provided.
   ConvEmformer
   HiFiGANVocoder
   SquimObjective
   SquimSubjective

Prototype Factory Functions of Beta Models
==========================================
...
@@ -534,3 +534,9 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
  journal={arXiv preprint arXiv:2005.13981},
  year={2020}
}
@article{manocha2022speech,
  title={Speech quality assessment through MOS using non-matching references},
  author={Manocha, Pranay and Kumar, Anurag},
  journal={arXiv preprint arXiv:2206.12285},
  year={2022}
}
import torch
from parameterized import parameterized
-from torchaudio.prototype.models import squim_objective_base
+from torchaudio.prototype.models import squim_objective_base, squim_subjective_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase


-class TestSQUIM(TorchaudioTestCase):
+class TestSquimObjective(TorchaudioTestCase):
    def _smoke_test_objective(self, model, device, dtype):
        model = model.to(device=device, dtype=dtype)
        model = model.eval()
@@ -57,3 +57,57 @@ class TestSquimObjective(TorchaudioTestCase):
        self.assertEqual(len(hyp_scores), len(ref_scores))
        for i in range(len(ref_scores)):
            self.assertEqual(hyp_scores[i], ref_scores[i])


class TestSquimSubjective(TorchaudioTestCase):
    def _smoke_test_subjective(self, model, device, dtype):
        model = model.to(device=device, dtype=dtype)
        model = model.eval()

        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        reference = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        model(waveforms, reference)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        model = squim_subjective_base()
        self._smoke_test_subjective(model, torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        model = squim_subjective_base()
        self._smoke_test_subjective(model, torch.device("cuda"), dtype)

    def test_batch_consistency(self):
        model = squim_subjective_base()
        model.eval()

        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        reference = torch.randn(batch_size, num_frames)
        ref_scores = model(waveforms, reference)
        hyp_scores = []
        for i in range(batch_size):
            scores = model(waveforms[i : i + 1], reference[i : i + 1])
            hyp_scores.append(scores)
        hyp_scores = torch.cat(hyp_scores)  # concatenate per-example scores into a batch
        self.assertEqual(hyp_scores, ref_scores)

    def test_torchscript_consistency(self):
        model = squim_subjective_base()
        model.eval()

        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        reference = torch.randn(batch_size, num_frames)
        ref_scores = model(waveforms, reference)
        scripted = torch_script(model)
        hyp_scores = scripted(waveforms, reference)
        self.assertEqual(hyp_scores, ref_scores)
@@ -11,7 +11,14 @@ from .conv_emformer import ConvEmformer
from .hifi_gan import hifigan_vocoder, hifigan_vocoder_v1, hifigan_vocoder_v2, hifigan_vocoder_v3, HiFiGANVocoder
from .rnnt import conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base, conformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearchBiasing
-from .squim import squim_objective_base, squim_objective_model, SquimObjective
+from .squim import (
+    squim_objective_base,
+    squim_objective_model,
+    squim_subjective_base,
+    squim_subjective_model,
+    SquimObjective,
+    SquimSubjective,
+)
__all__ = [
    "conformer_rnnt_base",
@@ -37,5 +44,8 @@ __all__ = [
    "hifigan_vocoder",
    "squim_objective_base",
    "squim_objective_model",
    "squim_subjective_base",
    "squim_subjective_model",
    "SquimObjective",
    "SquimSubjective",
]
from .objective import squim_objective_base, squim_objective_model, SquimObjective
from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective

__all__ = [
    "squim_objective_base",
    "squim_objective_model",
    "squim_subjective_base",
    "squim_subjective_model",
    "SquimObjective",
    "SquimSubjective",
]
from typing import Tuple

import torch
import torch.nn as nn
import torchaudio


class AttPool(nn.Module):
    """Attention-Pooling module that estimates the attention score.

    Args:
        input_dim (int): Input feature dimension.
        att_dim (int): Attention Tensor dimension.
    """

    def __init__(self, input_dim: int, att_dim: int):
        super(AttPool, self).__init__()
        self.linear1 = nn.Linear(input_dim, 1)
        self.linear2 = nn.Linear(input_dim, att_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and pooling.

        Args:
            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.

        Returns:
            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
        """
        att = self.linear1(x)  # (batch, time, 1)
        att = att.transpose(2, 1)  # (batch, 1, time)
        att = nn.functional.softmax(att, dim=2)
        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
        x = self.linear2(x)  # (batch, att_dim)
        return x
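
As a quick shape check (illustrative values only), the module maps `(batch, time, input_dim)` to `(batch, att_dim)`:

    pool = AttPool(input_dim=64, att_dim=5)
    x = torch.randn(3, 100, 64)  # (batch, time, feature_dim)
    print(pool(x).shape)         # torch.Size([3, 5])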

class Predictor(nn.Module):
    """Prediction module that applies attention pooling, then predicts the subjective metric score.

    Args:
        input_dim (int): Input feature dimension.
        att_dim (int): Attention Tensor dimension.
    """

    def __init__(self, input_dim: int, att_dim: int):
        super(Predictor, self).__init__()
        self.att_pool_layer = AttPool(input_dim, att_dim)
        self.att_dim = att_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict subjective evaluation metric score.

        Args:
            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.

        Returns:
            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
        """
        x = self.att_pool_layer(x)
        x = nn.functional.softmax(x, dim=1)
        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
        x = (x * B).sum(dim=1)
        return x
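
The `att_dim` pooled outputs act as logits over score anchors spaced linearly on `[0, 4]`; the predicted score is the expected anchor value. A hand-worked sketch with made-up logits:

    logits = torch.tensor([0.0, 0.0, 0.0, 0.0, 2.0])
    probs = torch.softmax(logits, dim=0)
    anchors = torch.linspace(0, 4, steps=5)  # tensor([0., 1., 2., 3., 4.])
    score = (probs * anchors).sum()          # ~3.12, pulled toward the dominant last anchor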

class SquimSubjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adapted from *NORESQA-MOS*
    :cite:`manocha2022speech`, which predicts MOS scores given the input speech and a non-matching reference.

    Args:
        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
        projector (torch.nn.Module): Projection layer that projects the SSL feature to a lower dimension.
        predictor (torch.nn.Module): Predictor that estimates the subjective scores.
    """

    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
        super(SquimSubjective, self).__init__()
        self.ssl_model = ssl_model
        self.projector = projector
        self.predictor = predictor

    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cut or repeat the reference Tensor so that it aligns with the waveform Tensor.

        Args:
            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.

        Returns:
            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors,
                both with dimensions `(batch, time)`.
        """
        T_waveform = waveform.shape[-1]
        T_reference = reference.shape[-1]
        if T_reference < T_waveform:
            num_padding = T_waveform // T_reference + 1
            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
        return waveform, reference[:, :T_waveform]
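
For example (illustrative lengths), a 16000-sample waveform with a 7000-sample reference repeats the reference `16000 // 7000 + 1 = 3` times and then truncates; a reference longer than the waveform is simply cut:

    reference = torch.randn(1, 7000)
    tiled = torch.cat([reference] * 3, dim=1)  # (1, 21000)
    aligned = tiled[:, :16000]                 # (1, 16000), matches the waveform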

    def forward(self, waveform: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Predict subjective evaluation metric score.

        Args:
            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.

        Returns:
            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
        """
        waveform, reference = self._align_shapes(waveform, reference)
        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
        concat = torch.cat((reference, waveform), dim=2)
        score_diff = self.predictor(concat)  # Score difference relative to the reference
        return 5 - score_diff
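
Because the Predictor output lies in `[0, 4]`, the returned score falls in the MOS range `[1, 5]`: the clean non-matching reference implicitly anchors the top of the scale, so a predicted difference of 1.2, for example, corresponds to an estimated MOS of 5 - 1.2 = 3.8.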

def squim_subjective_model(
    ssl_type: str,
    feat_dim: int,
    proj_dim: int,
    att_dim: int,
) -> SquimSubjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimSubjective` model.

    Args:
        ssl_type (str): Type of self-supervised learning (SSL) model.
            Must be one of ["wav2vec2_base", "wav2vec2_large"].
        feat_dim (int): Feature dimension of the SSL feature representation.
        proj_dim (int): Output dimension of the projection layer.
        att_dim (int): Dimension of the attention scores.
    """
    ssl_model = getattr(torchaudio.models, ssl_type)()
    projector = nn.Linear(feat_dim, proj_dim)
    predictor = Predictor(proj_dim * 2, att_dim)
    return SquimSubjective(ssl_model, projector, predictor)

def squim_subjective_base() -> SquimSubjective:
    """Build a :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
    return squim_subjective_model(
        ssl_type="wav2vec2_base",
        feat_dim=768,
        proj_dim=32,
        att_dim=5,
    )
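
A custom configuration sketch using the factory (values illustrative; `feat_dim` must match the output dimension of the chosen SSL model, e.g. 1024 for `wav2vec2_large`):

    model = squim_subjective_model(
        ssl_type="wav2vec2_large",
        feat_dim=1024,
        proj_dim=32,
        att_dim=5,
    )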