Commit 92b6847e authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add emformer hubert model architecture (#2836)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2836

Reviewed By: carolineechen

Differential Revision: D41208630

Pulled By: nateanl

fbshipit-source-id: 625e1651f0b8a6e20876409739cf7084cb7c748b
parent 13063f9b
@@ -14,6 +14,16 @@ conformer_rnnt_base
 .. autofunction:: conformer_rnnt_base
 
+emformer_hubert_model
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: emformer_hubert_model
+
+emformer_hubert_base
+~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: emformer_hubert_base
+
 ConvEmformer
 ~~~~~~~~~~~~
...
 import torch
 from parameterized import parameterized
-from torchaudio.prototype.models import conformer_wav2vec2_base
+from torchaudio.prototype.models import conformer_wav2vec2_base, emformer_hubert_base
 from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
 
 
-class TestConformerWav2Vec2(TorchaudioTestCase):
-    def _smoke_test(self, model, device, dtype):
+class TestSSLModel(TorchaudioTestCase):
+    def _smoke_test(self, model, feature_dim, device, dtype):
         model = model.to(device=device, dtype=dtype)
         model = model.eval()
 
-        batch_size, num_frames, in_features = 3, 1024, 64
-        features = torch.randn(batch_size, num_frames, in_features, device=device, dtype=dtype)
+        batch_size, num_frames = 3, 1024
+        features = torch.randn(batch_size, num_frames, feature_dim, device=device, dtype=dtype)
         lengths = torch.randint(
             low=0,
             high=num_frames,
@@ -22,25 +22,48 @@ class TestConformerWav2Vec2(TorchaudioTestCase):
         model(features, lengths)
 
-    @parameterized.expand([(torch.float32,), (torch.float64,)])
-    def test_cpu_smoke_test(self, dtype):
-        model = conformer_wav2vec2_base()
-        self._smoke_test(model, torch.device("cpu"), dtype)
-
-    @parameterized.expand([(torch.float32,), (torch.float64,)])
+    @parameterized.expand(
+        [
+            (conformer_wav2vec2_base, torch.float32, 64),
+            (conformer_wav2vec2_base, torch.float64, 64),
+            (emformer_hubert_base, torch.float32, 80),
+            (emformer_hubert_base, torch.float64, 80),
+        ]
+    )
+    def test_cpu_smoke_test(self, model, dtype, feature_dim):
+        model = model()
+        self._smoke_test(model, feature_dim, torch.device("cpu"), dtype)
+
+    @parameterized.expand(
+        [
+            (conformer_wav2vec2_base, torch.float32, 64),
+            (conformer_wav2vec2_base, torch.float64, 64),
+            (emformer_hubert_base, torch.float32, 80),
+            (emformer_hubert_base, torch.float64, 80),
+        ]
+    )
     @skipIfNoCuda
-    def test_cuda_smoke_test(self, dtype):
-        model = conformer_wav2vec2_base()
-        self._smoke_test(model, torch.device("cuda"), dtype)
+    def test_cuda_smoke_test(self, model, dtype, feature_dim):
+        model = model()
+        self._smoke_test(model, feature_dim, torch.device("cuda"), dtype)
 
-    def test_extract_feature(self):
-        model = conformer_wav2vec2_base()
+    @parameterized.expand(
+        [
+            (conformer_wav2vec2_base, 64),
+            (emformer_hubert_base, 80),
+        ]
+    )
+    def test_extract_feature(self, model, feature_dim):
+        model = model()
         model.eval()
 
-        batch_size, num_frames, in_features = 3, 1024, 64
-        num_layers = len(model.encoder.conformer)
+        batch_size, num_frames = 3, 1024
+        if feature_dim == 64:
+            num_layers = len(model.encoder.conformer)
+        else:
+            num_layers = len(model.encoder.emformer.emformer_layers)
 
-        features = torch.randn(batch_size, num_frames, in_features)
+        features = torch.randn(batch_size, num_frames, feature_dim)
         lengths = torch.randint(
             low=0,
             high=num_frames,
@@ -63,12 +86,18 @@ class TestConformerWav2Vec2(TorchaudioTestCase):
             self.assertEqual(all_features[i], feats[i])
         assert lengths_.shape == torch.Size([batch_size])
 
-    def test_zero_length(self):
-        model = conformer_wav2vec2_base()
+    @parameterized.expand(
+        [
+            (conformer_wav2vec2_base, 64),
+            (emformer_hubert_base, 80),
+        ]
+    )
+    def test_zero_length(self, model, feature_dim):
+        model = model()
         model.eval()
 
-        batch_size, num_frames, in_features = 3, 1024, 64
-        features = torch.randn(batch_size, num_frames, in_features)
+        batch_size, num_frames = 3, 1024
+        features = torch.randn(batch_size, num_frames, feature_dim)
         input_lengths = torch.zeros(batch_size)
         _, output_lengths = model(features, input_lengths)
         self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@@ -76,12 +105,18 @@ class TestConformerWav2Vec2(TorchaudioTestCase):
         _, output_lengths = model.extract_features(features, input_lengths)
         self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
 
-    def test_torchscript_consistency(self):
-        model = conformer_wav2vec2_base()
+    @parameterized.expand(
+        [
+            (conformer_wav2vec2_base, 64),
+            (emformer_hubert_base, 80),
+        ]
+    )
+    def test_torchscript_consistency(self, model, feature_dim):
+        model = model()
         model.eval()
 
-        batch_size, num_frames, in_features = 3, 1024, 64
-        features = torch.randn(batch_size, num_frames, in_features)
+        batch_size, num_frames = 3, 1024
+        features = torch.randn(batch_size, num_frames, feature_dim)
         lengths = torch.randint(
             low=0,
             high=num_frames,
...
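For reference, the following is a minimal standalone smoke check mirroring the parameterized tests above (a sketch for illustration, not part of the PR). It builds the new Emformer HuBERT base model, which takes 80-dimensional log-mel features, and runs both the forward pass and layer-wise feature extraction:

import torch
from torchaudio.prototype.models import emformer_hubert_base

model = emformer_hubert_base().eval()

# Same shapes as the tests: batch of 3, 1024 frames, 80 mel bins per frame.
features = torch.randn(3, 1024, 80)
lengths = torch.randint(low=0, high=1024, size=(3,))

with torch.no_grad():
    # Encoder output is 1024-dimensional; lengths shrink by the time-reduction stride of 4.
    outputs, output_lengths = model(features, lengths)
    # One Tensor per Emformer layer (20 in the base configuration).
    layer_outputs, _ = model.extract_features(features, lengths)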
 from ._conformer_wav2vec2 import conformer_wav2vec2_base, conformer_wav2vec2_model
+from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
 from .conv_emformer import ConvEmformer
 from .rnnt import conformer_rnnt_base, conformer_rnnt_model
@@ -8,4 +9,6 @@ __all__ = [
     "ConvEmformer",
     "conformer_wav2vec2_model",
     "conformer_wav2vec2_base",
+    "emformer_hubert_base",
+    "emformer_hubert_model",
 ]
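With this change, both factory functions are importable directly from the prototype models namespace, for example (illustrative only):

from torchaudio.prototype.models import emformer_hubert_base, emformer_hubert_model

model = emformer_hubert_base()  # 20-layer Emformer HuBERT with the default configuration

The new file implementing these factories follows.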
from typing import List, Optional, Tuple
import torch
from torchaudio.models import Wav2Vec2Model
from torchaudio.models.emformer import Emformer
from torchaudio.models.rnnt import _TimeReduction
class FeatureEncoder(torch.nn.Module):
"""Extract features from log-mel spectrogram input. Consists of linear layer and time reduction layer.
Args:
input_dim (int): The feature dimension of log-mel spectrogram feature.
output_dim (int): The feature dimension after linear layer.
use_bias (bool): If ``True``, enable bias parameter in the linear layer.
stride (int): Number of frames to merge for the output frame.
"""
def __init__(self, input_dim: int, output_dim: int, use_bias: bool, stride: int):
super().__init__()
self.linear = torch.nn.Linear(input_dim, output_dim, bias=use_bias)
self.time_reduction = _TimeReduction(stride)
def forward(
self, input: torch.Tensor, lengths: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
input (torch.Tensor): The log-mel spectrogram input.
Tensor with dimensions `(batch, time, input_dim)`.
lengths (torch.Tensor or None): Valid length of each input sample.
Tensor with dimension `(batch, )`.
Returns:
(torch.Tensor, torch.Tensor or None):
torch.Tensor
Returned feature Tensor after linear layer and time reduction layer.
Tensor with dimensions `(batch, time // stride, output_dim)`.
torch.Tensor or None
The reduced lengths Tensor.
"""
output = self.linear(input)
if lengths is None:
B, T, _ = input.shape
dummy_lengths = torch.full((B,), T)
output, _ = self.time_reduction(output, dummy_lengths)
else:
output, lengths = self.time_reduction(output, lengths)
return output, lengths
class EmformerEncoder(torch.nn.Module):
"""Emformer Encoder class for HuBERT pre-training. Consists of emformer module,
linear layer and layer normalization layer.
Args:
emformer (torch.nn.Module):
:py:class:`torchaudio.models.Emformer` module that consists of a list of emformer layers.
output_linear (torch.nn.Module):
Linear layer after emformer module.
layer_norm (torch.nn.Module or None, optional):
If ``None``, don't apply layer normalization to the output.
"""
def __init__(
self,
emformer: torch.nn.Module,
output_linear: torch.nn.Module,
layer_norm: Optional[torch.nn.Module] = None,
):
super().__init__()
self.emformer = emformer
self.output_linear = output_linear
self.layer_norm = layer_norm
def forward(
self,
input: torch.Tensor,
lengths: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): The input feature for emformer encoder.
Tensor with dimensions `(batch, time, feature_dim)`.
lengths (torch.Tensor or None): Valid length of each input sample.
Tensor with dimension `(batch, )`.
Returns:
torch.Tensor: The feature Tensor after emformer encoder.
"""
if lengths is None:
B, T, _ = input.shape
dummy_lengths = torch.full((B,), T)
output, _ = self.emformer(input, dummy_lengths)
else:
output, lengths = self.emformer(input, lengths)
output = self.output_linear(output)
if self.layer_norm is not None:
output = self.layer_norm(output)
return output
def extract_features(
self,
input: torch.Tensor,
lengths: Optional[torch.Tensor],
num_layers: Optional[int] = None,
) -> List[torch.Tensor]:
"""Extract output Tensors of the emformer layers.
Args:
input (torch.Tensor): The input feature for emformer encoder.
Tensor with dimensions `(batch, time, feature_dim)`.
lengths (torch.Tensor or None): Valid length of each input sample.
Tensor with dimension `(batch, )`.
num_layers (int or None, optional): If not ``None``, returns the first
`num_layers` layers of Tensors as the output, otherwise returns the
Tensors from all emformer layers.
Returns:
List[torch.Tensor]:
Output Tensors of selected emformer layers.
"""
if num_layers is not None:
if not 0 < num_layers <= len(self.emformer.emformer_layers):
raise ValueError(f"`num_layers` must be between [1, {len(self.emformer.emformer_layers)}]")
ret: List[torch.Tensor] = []
input = input.permute(1, 0, 2)
right_context = self.emformer._gen_right_context(input)
utterance = input[: input.size(0) - self.emformer.right_context_length]
attention_mask = self.emformer._gen_attention_mask(utterance)
mems = (
self.emformer.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
if self.emformer.use_mem
else torch.empty(0).to(dtype=input.dtype, device=input.device)
)
output = utterance
        if lengths is None:
            # `input` was permuted to (time, batch, feature) above.
            T, B, _ = input.shape
            lengths = torch.full((B,), T)
for layer in self.emformer.emformer_layers:
output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
ret.append(output.permute(1, 0, 2))
if num_layers is not None and len(ret) >= num_layers:
return ret
return ret
def _get_emformer_feature_extractor(input_dim: int, output_dim: int, use_bias: bool, stride: int) -> FeatureEncoder:
"""Construct FeatureEncoder for emformer model.
Args:
input_dim (int): The feature dimension of log-mel spectrogram feature.
output_dim (int): The feature dimension after linear layer.
use_bias (bool): If ``True``, enable bias parameter in the linear layer.
stride (int): Number of frames to merge for the output frame.
Returns:
FeatureEncoder: The resulting FeatureEncoder module.
"""
return FeatureEncoder(input_dim, output_dim, use_bias, stride)
def _get_emformer_encoder(
input_dim: int,
output_dim: int,
num_heads: int,
ffn_dim: int,
num_layers: int,
segment_length: int,
left_context_length: int,
right_context_length: int,
dropout: float,
activation: str,
max_memory_size: int,
weight_init_scale_strategy: Optional[str],
tanh_on_mem: bool,
) -> EmformerEncoder:
"""Construct EmformerEncoder for emformer model.
Args:
input_dim (int): The feature dimension of input Tensor.
output_dim (int): The feature dimension after EmformerEncoder.
num_heads (int): Number of attention heads in each Emformer layer.
ffn_dim: (int): Hidden layer dimension of feedforward network.
num_layers (int): Number of Emformer layers to instantiate.
segment_length (int): Length of each input segment.
left_context_length (int): Length of left context.
right_context_length (int): Length of right context.
dropout (float): Dropout probability.
activation (str): Activation function to use in each Emformer layer's
feedforward network. Must be one of ("relu", "gelu", "silu").
max_memory_size (int): Maximum number of memory elements to use.
weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
strategy. Must be one of ("depthwise", "constant", ``None``).
tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
Returns:
EmformerEncoder: The resulting EmformerEncoder module.
"""
emformer = Emformer(
input_dim=input_dim,
num_heads=num_heads,
ffn_dim=ffn_dim,
num_layers=num_layers,
segment_length=segment_length,
left_context_length=left_context_length,
right_context_length=right_context_length,
dropout=dropout,
activation=activation,
max_memory_size=max_memory_size,
weight_init_scale_strategy=weight_init_scale_strategy,
tanh_on_mem=tanh_on_mem,
)
output_linear = torch.nn.Linear(input_dim, output_dim)
return EmformerEncoder(emformer, output_linear)
def emformer_hubert_model(
extractor_input_dim: int,
extractor_output_dim: int,
extractor_use_bias: bool,
extractor_stride: int,
encoder_input_dim: int,
encoder_output_dim: int,
encoder_num_heads: int,
encoder_ffn_dim: int,
encoder_num_layers: int,
encoder_segment_length: int,
encoder_left_context_length: int,
encoder_right_context_length: int,
encoder_dropout: float,
encoder_activation: str,
encoder_max_memory_size: int,
encoder_weight_init_scale_strategy: Optional[str],
encoder_tanh_on_mem: bool,
) -> Wav2Vec2Model:
"""Build a custom Emformer HuBERT model.
Args:
extractor_input_dim (int): The input dimension for feature extractor.
extractor_output_dim (int): The output dimension after feature extractor.
extractor_use_bias (bool): If ``True``, enable bias parameter in the linear layer of feature extractor.
extractor_stride (int): Number of frames to merge for the output frame in feature extractor.
encoder_input_dim (int): The input dimension for Emformer layer.
encoder_output_dim (int): The output dimension after EmformerEncoder.
encoder_num_heads (int): Number of attention heads in each Emformer layer.
encoder_ffn_dim (int): Hidden layer dimension of feedforward network in Emformer.
encoder_num_layers (int): Number of Emformer layers to instantiate.
encoder_segment_length (int): Length of each input segment.
encoder_left_context_length (int): Length of left context.
encoder_right_context_length (int): Length of right context.
encoder_dropout (float): Dropout probability.
encoder_activation (str): Activation function to use in each Emformer layer's
feedforward network. Must be one of ("relu", "gelu", "silu").
encoder_max_memory_size (int): Maximum number of memory elements to use.
encoder_weight_init_scale_strategy (str or None): Per-layer weight initialization scaling
strategy. Must be one of ("depthwise", "constant", ``None``).
encoder_tanh_on_mem (bool): If ``True``, applies tanh to memory elements.
Returns:
Wav2Vec2Model:
The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
with a :py:class:`torchaudio.models.Emformer` encoder.
"""
feature_extractor = _get_emformer_feature_extractor(
extractor_input_dim, extractor_output_dim, extractor_use_bias, extractor_stride
)
emformer = _get_emformer_encoder(
encoder_input_dim,
encoder_output_dim,
encoder_num_heads,
encoder_ffn_dim,
encoder_num_layers,
encoder_segment_length,
encoder_left_context_length,
encoder_right_context_length,
encoder_dropout,
encoder_activation,
encoder_max_memory_size,
encoder_weight_init_scale_strategy,
encoder_tanh_on_mem,
)
return Wav2Vec2Model(feature_extractor, emformer)
def emformer_hubert_base(
extractor_input_dim: int = 80,
extractor_output_dim: int = 128,
encoder_dropout: float = 0.1,
) -> Wav2Vec2Model:
"""Build Emformer HuBERT Model with 20 Emformer layers.
Args:
extractor_input_dim (int, optional): The input dimension for feature extractor. (Default: 80)
extractor_output_dim (int, optional): The output dimension after feature extractor. (Default: 128)
encoder_dropout (float, optional): Dropout probability in Emformer. (Default: 0.1)
Returns:
Wav2Vec2Model:
The resulting :py:class:`torchaudio.models.Wav2Vec2Model` model
with a :py:class:`torchaudio.models.Emformer` encoder.
"""
return emformer_hubert_model(
extractor_input_dim=extractor_input_dim,
extractor_output_dim=extractor_output_dim,
extractor_use_bias=False,
extractor_stride=4,
encoder_input_dim=512,
encoder_output_dim=1024,
encoder_num_heads=8,
encoder_ffn_dim=2048,
encoder_num_layers=20,
encoder_segment_length=4,
encoder_left_context_length=30,
encoder_right_context_length=1,
encoder_dropout=encoder_dropout,
encoder_activation="gelu",
encoder_max_memory_size=0,
encoder_weight_init_scale_strategy="depthwise",
encoder_tanh_on_mem=True,
)
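When building a custom configuration with emformer_hubert_model, note that the time reduction layer concatenates extractor_stride consecutive frames along the feature dimension, so encoder_input_dim should equal extractor_output_dim * extractor_stride (128 * 4 = 512 in emformer_hubert_base above). The following is a minimal sketch with smaller, hypothetical dimensions, not a configuration from the PR:

import torch
from torchaudio.prototype.models import emformer_hubert_model

model = emformer_hubert_model(
    extractor_input_dim=80,        # log-mel bins
    extractor_output_dim=64,
    extractor_use_bias=False,
    extractor_stride=4,
    encoder_input_dim=256,         # 64 * 4, matches the time-reduced feature dimension
    encoder_output_dim=512,
    encoder_num_heads=4,
    encoder_ffn_dim=1024,
    encoder_num_layers=4,
    encoder_segment_length=4,
    encoder_left_context_length=30,
    encoder_right_context_length=1,
    encoder_dropout=0.1,
    encoder_activation="gelu",
    encoder_max_memory_size=0,
    encoder_weight_init_scale_strategy="depthwise",
    encoder_tanh_on_mem=True,
)

features = torch.randn(2, 400, 80)    # (batch, frames, mel bins)
lengths = torch.full((2,), 400)
outputs, output_lengths = model(features, lengths)  # outputs: (batch, ~frames // 4, 512)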