Commit b0795ebe authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add aux_num_out to emformer_hubert_model (#2868)

Summary:
- `layer_norm` in `EmformerEncoder` is always constructed when building via `emformer_hubert_model`, so change its type from optional to non-optional.
- add `aux_num_out` to `emformer_hubert_model` to support fine-tuning the model (see the usage sketch below).
- update unit tests.

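For reference, a minimal usage sketch of the new option (not part of this commit; the output size of 500 is an arbitrary placeholder, and shapes follow the unit tests below):

```python
import torch
from torchaudio.prototype.models import emformer_hubert_base

# Attach an aux linear head on top of the encoder; 500 output classes is an
# arbitrary example value.
model = emformer_hubert_base(aux_num_out=500).eval()

batch_size, num_frames, feature_dim = 3, 1024, 80  # shapes match the unit tests
features = torch.randn(batch_size, num_frames, feature_dim)
lengths = torch.randint(low=0, high=num_frames, size=(batch_size,))

with torch.no_grad():
    # With an aux head attached, the first output holds per-frame logits of
    # size aux_num_out instead of the raw encoder features.
    logits, out_lengths = model(features, lengths)
```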
Pull Request resolved: https://github.com/pytorch/audio/pull/2868

Reviewed By: carolineechen

Differential Revision: D41451311

Pulled By: nateanl

fbshipit-source-id: 5fa0f19255e4f01e001d62f8689e36f134030083
parent 52e89756
 import torch
 from parameterized import parameterized
 from torchaudio.prototype.models import conformer_wav2vec2_base, emformer_hubert_base
-from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
+from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase


 class TestSSLModel(TorchaudioTestCase):
@@ -22,39 +22,37 @@ class TestSSLModel(TorchaudioTestCase):
         model(features, lengths)

-    @parameterized.expand(
-        [
-            (conformer_wav2vec2_base, torch.float32, 64),
-            (conformer_wav2vec2_base, torch.float64, 64),
-            (emformer_hubert_base, torch.float32, 80),
-            (emformer_hubert_base, torch.float64, 80),
-        ]
+    @nested_params(
+        [(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)],
+        [torch.float32, torch.float64],
     )
-    def test_cpu_smoke_test(self, model, dtype, feature_dim):
+    def test_cpu_smoke_test(self, model_feature_dim, dtype):
+        model, feature_dim = model_feature_dim
         model = model()
         self._smoke_test(model, feature_dim, torch.device("cpu"), dtype)

-    @parameterized.expand(
-        [
-            (conformer_wav2vec2_base, torch.float32, 64),
-            (conformer_wav2vec2_base, torch.float64, 64),
-            (emformer_hubert_base, torch.float32, 80),
-            (emformer_hubert_base, torch.float64, 80),
-        ]
+    @nested_params(
+        [(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)],
+        [torch.float32, torch.float64],
     )
     @skipIfNoCuda
-    def test_cuda_smoke_test(self, model, dtype, feature_dim):
+    def test_cuda_smoke_test(self, model_feature_dim, dtype):
+        model, feature_dim = model_feature_dim
         model = model()
         self._smoke_test(model, feature_dim, torch.device("cuda"), dtype)

     @parameterized.expand(
         [
-            (conformer_wav2vec2_base, 64),
-            (emformer_hubert_base, 80),
+            (conformer_wav2vec2_base, 64, None),
+            (emformer_hubert_base, 80, None),
+            (emformer_hubert_base, 80, 512),
         ]
     )
-    def test_extract_feature(self, model, feature_dim):
-        model = model()
+    def test_extract_feature(self, model, feature_dim, aux_num_out):
+        if aux_num_out is not None:
+            model = model(aux_num_out=aux_num_out)
+        else:
+            model = model()
         model.eval()

         batch_size, num_frames = 3, 1024
@@ -88,12 +86,16 @@ class TestSSLModel(TorchaudioTestCase):
     @parameterized.expand(
         [
-            (conformer_wav2vec2_base, 64),
-            (emformer_hubert_base, 80),
+            (conformer_wav2vec2_base, 64, None),
+            (emformer_hubert_base, 80, None),
+            (emformer_hubert_base, 80, 512),
         ]
     )
-    def test_zero_length(self, model, feature_dim):
-        model = model()
+    def test_zero_length(self, model, feature_dim, aux_num_out):
+        if aux_num_out is not None:
+            model = model(aux_num_out=aux_num_out)
+        else:
+            model = model()
         model.eval()

         batch_size, num_frames = 3, 1024
@@ -107,12 +109,16 @@ class TestSSLModel(TorchaudioTestCase):
     @parameterized.expand(
         [
-            (conformer_wav2vec2_base, 64),
-            (emformer_hubert_base, 80),
+            (conformer_wav2vec2_base, 64, None),
+            (emformer_hubert_base, 80, None),
+            (emformer_hubert_base, 80, 512),
         ]
     )
-    def test_torchscript_consistency(self, model, feature_dim):
-        model = model()
+    def test_torchscript_consistency(self, model, feature_dim, aux_num_out):
+        if aux_num_out is not None:
+            model = model(aux_num_out=aux_num_out)
+        else:
+            model = model()
         model.eval()

         batch_size, num_frames = 3, 1024
@@ -58,15 +58,15 @@ class EmformerEncoder(torch.nn.Module):
             :py:class:`torchaudio.models.Emformer` module that consists of a list of emformer layers.
         output_linear (torch.nn.Module):
             Linear layer after emformer module.
-        layer_norm (torch.nn.Module or None, optional):
-            If ``None``, don't apply layer normalization to the output.
+        layer_norm (torch.nn.Module):
+            Apply layer normalization to the output.
     """

     def __init__(
         self,
         emformer: torch.nn.Module,
         output_linear: torch.nn.Module,
-        layer_norm: Optional[torch.nn.Module] = None,
+        layer_norm: torch.nn.Module,
     ):
         super().__init__()
         self.emformer = emformer
@@ -95,8 +95,7 @@ class EmformerEncoder(torch.nn.Module):
         else:
             output, lengths = self.emformer(input, lengths)
         output = self.output_linear(output)
-        if self.layer_norm is not None:
-            output = self.layer_norm(output)
+        output = self.layer_norm(output)
         return output

     def extract_features(
@@ -214,7 +213,8 @@ def _get_emformer_encoder(
         tanh_on_mem=tanh_on_mem,
     )
     output_linear = torch.nn.Linear(input_dim, output_dim)
-    return EmformerEncoder(emformer, output_linear)
+    layer_norm = torch.nn.LayerNorm(output_dim)
+    return EmformerEncoder(emformer, output_linear, layer_norm)


 def emformer_hubert_model(
@@ -235,6 +235,7 @@ def emformer_hubert_model(
     encoder_max_memory_size: int,
     encoder_weight_init_scale_strategy: Optional[str],
     encoder_tanh_on_mem: bool,
+    aux_num_out: Optional[int],
 ) -> Wav2Vec2Model:
     """Build a custom Emformer HuBERT model.
@@ -258,6 +259,9 @@ def emformer_hubert_model(
         encoder_weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
             strategy. Must be one of ("depthwise", "constant", ``None``).
         encoder_tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
+        aux_num_out (int or None):
+            When provided, attach an extra linear layer on top of encoder, which can be
+            used for fine-tuning.

     Returns:
         Wav2Vec2Model:
@@ -282,13 +286,17 @@ def emformer_hubert_model(
         encoder_weight_init_scale_strategy,
         encoder_tanh_on_mem,
     )
-    return Wav2Vec2Model(feature_extractor, emformer)
+    aux = None
+    if aux_num_out is not None:
+        aux = torch.nn.Linear(in_features=encoder_output_dim, out_features=aux_num_out)
+    return Wav2Vec2Model(feature_extractor, emformer, aux)


 def emformer_hubert_base(
     extractor_input_dim: int = 80,
     extractor_output_dim: int = 128,
     encoder_dropout: float = 0.1,
+    aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
     """Build Emformer HuBERT Model with 20 Emformer layers.
@@ -296,6 +304,7 @@ def emformer_hubert_base(
         extractor_input_dim (int, optional): The input dimension for feature extractor. (Default: 80)
         extractor_output_dim (int, optional): The output dimension after feature extractor. (Default: 128)
         encoder_dropout (float, optional): Dropout probability in Emformer. (Default: 0.1)
+        aux_num_out (int or None, optional): Output dimension of aux layer for fine-tuning. (Default: ``None``)

     Returns:
         Wav2Vec2Model:
@@ -320,4 +329,5 @@ def emformer_hubert_base(
         encoder_max_memory_size=0,
         encoder_weight_init_scale_strategy="depthwise",
         encoder_tanh_on_mem=True,
+        aux_num_out=aux_num_out,
     )
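To illustrate how the new aux head could be used in practice, here is a hypothetical fine-tuning step (not from this commit; the vocabulary size, blank index, and CTC objective are all assumptions made for the example):

```python
import torch
from torchaudio.prototype.models import emformer_hubert_base

num_classes = 32  # assumed vocabulary size, with the CTC blank at index 0
model = emformer_hubert_base(aux_num_out=num_classes)

features = torch.randn(2, 1024, 80)  # (batch, frames, feature_dim)
feat_lengths = torch.tensor([1024, 900])
logits, out_lengths = model(features, feat_lengths)  # logits: (batch, frames', num_classes)

targets = torch.randint(1, num_classes, (2, 50))  # dummy label sequences
target_lengths = torch.tensor([50, 40])

# CTCLoss expects log-probabilities shaped (time, batch, classes).
ctc_loss = torch.nn.CTCLoss(blank=0)
loss = ctc_loss(logits.log_softmax(dim=-1).transpose(0, 1), targets, out_lengths, target_lengths)
loss.backward()
```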