Commit 8bde6a54 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add conformer wav2vec2 pretrain model (#2827)

Summary:
modeled after [paper](https://arxiv.org/pdf/2110.07313.pdf) and internal flow f288347302

internal comparison tests: D40080919

Pull Request resolved: https://github.com/pytorch/audio/pull/2827

Reviewed By: nateanl

Differential Revision: D41569046

Pulled By: carolineechen

fbshipit-source-id: 43c5313074af05972d93da55b2029c746b75c380
parent b0795ebe
...@@ -33,6 +33,13 @@ ConvEmformer ...@@ -33,6 +33,13 @@ ConvEmformer
.. automethod:: infer .. automethod:: infer
ConformerWav2Vec2PretrainModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ConformerWav2Vec2PretrainModel
.. automethod:: forward
conformer_wav2vec2_model conformer_wav2vec2_model
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -42,3 +49,18 @@ conformer_wav2vec2_base ...@@ -42,3 +49,18 @@ conformer_wav2vec2_base
~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_wav2vec2_base .. autofunction:: conformer_wav2vec2_base
conformer_wav2vec2_pretrain_model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_wav2vec2_pretrain_model
conformer_wav2vec2_pretrain_base
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_wav2vec2_pretrain_base
conformer_wav2vec2_pretrain_large
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: conformer_wav2vec2_pretrain_large
import torch
from parameterized import parameterized
from torchaudio.prototype.models import (
conformer_wav2vec2_base,
conformer_wav2vec2_pretrain_base,
conformer_wav2vec2_pretrain_large,
)
from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase
class TestConformerWav2Vec2(TorchaudioTestCase):
    """Smoke, feature-extraction and TorchScript tests for the conformer wav2vec2 models."""

    def _smoke_test(self, model, device, dtype):
        """Run a single forward pass of ``model`` on random features and random lengths."""
        model = model.to(device=device, dtype=dtype).eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features, device=device, dtype=dtype)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size], device=device)

        model(features, lengths)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """The base model runs on CPU for each tested dtype."""
        self._smoke_test(conformer_wav2vec2_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """The base model runs on CUDA for each tested dtype."""
        self._smoke_test(conformer_wav2vec2_base(), torch.device("cuda"), dtype)

    @nested_params(
        [conformer_wav2vec2_pretrain_base, conformer_wav2vec2_pretrain_large],
        [torch.float32, torch.float64],
    )
    def test_pretrain_cpu_smoke_test(self, model, dtype):
        """Each pre-train factory builds a model that runs on CPU."""
        self._smoke_test(model(), torch.device("cpu"), dtype)

    @nested_params(
        [conformer_wav2vec2_pretrain_base, conformer_wav2vec2_pretrain_large],
        [torch.float32, torch.float64],
    )
    @skipIfNoCuda
    def test_pretrain_cuda_smoke_test(self, model, dtype):
        """Each pre-train factory builds a model that runs on CUDA."""
        self._smoke_test(model(), torch.device("cuda"), dtype)

    def test_extract_feature(self):
        """``extract_features`` returns one output per requested layer, consistent across depths."""
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        num_layers = len(model.encoder.conformer)

        features = torch.randn(batch_size, num_frames, in_features)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size])

        # Reference run: ``num_layers=None`` requests the output of every layer.
        all_features, lengths_ = model.extract_features(features, lengths, num_layers=None)
        assert len(all_features) == num_layers
        for layer_output in all_features:
            assert layer_output.ndim == 3
            assert layer_output.shape[0] == batch_size
        assert lengths_.shape == torch.Size([batch_size])

        # Every truncated run must reproduce the matching prefix of the reference run.
        for depth in range(1, num_layers + 1):
            feats, lengths_ = model.extract_features(features, lengths, num_layers=depth)
            assert len(feats) == depth
            for expected, actual in zip(all_features, feats):
                self.assertEqual(expected, actual)
            assert lengths_.shape == torch.Size([batch_size])

    def test_zero_length(self):
        """All-zero input lengths produce all-zero output lengths."""
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features)
        input_lengths = torch.zeros(batch_size)

        _, output_lengths = model(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

        _, output_lengths = model.extract_features(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

    def test_torchscript_consistency(self):
        """The scripted model gives the same outputs as the eager model."""
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size])

        ref_out, ref_len = model(features, lengths)

        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(features, lengths)

        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.prototype.models import conformer_wav2vec2_base, emformer_hubert_base from torchaudio.prototype.models import conformer_wav2vec2_base, conformer_wav2vec2_pretrain_base, emformer_hubert_base
from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase
...@@ -23,7 +23,7 @@ class TestSSLModel(TorchaudioTestCase): ...@@ -23,7 +23,7 @@ class TestSSLModel(TorchaudioTestCase):
model(features, lengths) model(features, lengths)
@nested_params( @nested_params(
[(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)], [(conformer_wav2vec2_base, 64), (conformer_wav2vec2_pretrain_base, 64), (emformer_hubert_base, 80)],
[torch.float32, torch.float64], [torch.float32, torch.float64],
) )
def test_cpu_smoke_test(self, model_feature_dim, dtype): def test_cpu_smoke_test(self, model_feature_dim, dtype):
...@@ -32,7 +32,7 @@ class TestSSLModel(TorchaudioTestCase): ...@@ -32,7 +32,7 @@ class TestSSLModel(TorchaudioTestCase):
self._smoke_test(model, feature_dim, torch.device("cpu"), dtype) self._smoke_test(model, feature_dim, torch.device("cpu"), dtype)
@nested_params( @nested_params(
[(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)], [(conformer_wav2vec2_base, 64), (conformer_wav2vec2_pretrain_base, 64), (emformer_hubert_base, 80)],
[torch.float32, torch.float64], [torch.float32, torch.float64],
) )
@skipIfNoCuda @skipIfNoCuda
......
from ._conformer_wav2vec2 import conformer_wav2vec2_base, conformer_wav2vec2_model from ._conformer_wav2vec2 import (
conformer_wav2vec2_base,
conformer_wav2vec2_model,
conformer_wav2vec2_pretrain_base,
conformer_wav2vec2_pretrain_large,
conformer_wav2vec2_pretrain_model,
ConformerWav2Vec2PretrainModel,
)
from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
from .conv_emformer import ConvEmformer from .conv_emformer import ConvEmformer
from .rnnt import conformer_rnnt_base, conformer_rnnt_model from .rnnt import conformer_rnnt_base, conformer_rnnt_model
...@@ -9,6 +16,10 @@ __all__ = [ ...@@ -9,6 +16,10 @@ __all__ = [
"ConvEmformer", "ConvEmformer",
"conformer_wav2vec2_model", "conformer_wav2vec2_model",
"conformer_wav2vec2_base", "conformer_wav2vec2_base",
"conformer_wav2vec2_pretrain_model",
"conformer_wav2vec2_pretrain_base",
"conformer_wav2vec2_pretrain_large",
"ConformerWav2Vec2PretrainModel",
"emformer_hubert_base", "emformer_hubert_base",
"emformer_hubert_model", "emformer_hubert_model",
] ]
...@@ -2,20 +2,129 @@ from typing import List, Optional, Tuple, Union ...@@ -2,20 +2,129 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import Module from torch.nn import Module, ModuleList
from torchaudio.models import Wav2Vec2Model from torchaudio.models import Wav2Vec2Model
from torchaudio.models.conformer import ConformerLayer from torchaudio.models.conformer import ConformerLayer
from torchaudio.models.rnnt import _TimeReduction from torchaudio.models.rnnt import _TimeReduction
from torchaudio.models.wav2vec2 import components from torchaudio.models.wav2vec2 import components
def _buffered_arange(max) -> Tensor:
"""Compute arange using a buffered tensor across function calls.
Produces same result as torch.arange(end=max).
Args:
max (int): Ending value for arange.
"""
if not hasattr(_buffered_arange, "buf"):
_buffered_arange.buf = torch.LongTensor()
if max > _buffered_arange.buf.numel():
_buffered_arange.buf.resize_(max)
torch.arange(max, out=_buffered_arange.buf)
return _buffered_arange.buf[:max]
def _sample_negatives(input: Tensor, num_negatives: int, cross_sample_negatives: int) -> Tuple[Tensor, Tensor]:
"""Sample negative examples from masked input.
Args:
input (Tensor): Tensor of dimension `(batch, frame, dim)`.
num_negatives (int): Number of negative examples to sample.
cross_sample_negatives (int): Number of negative examples to cross sample.
Returns:
(Tensor, Tensor):
Tensor
The negative samples.
Tensor
The indices of the negative samples.
"""
if num_negatives == 0 and cross_sample_negatives == 0:
return (
torch.zeros(0).to(input.device, input.dtype),
torch.zeros(0).to(input.device, input.dtype),
)
B, T, D = input.shape
input = input.view(-1, D)
cross_high = T * B
high = T
assert high > 1
if num_negatives > 0:
tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, num_negatives).flatten()
neg_idxs = torch.randint(low=0, high=high - 1, size=(B, num_negatives * T))
neg_idxs[neg_idxs >= tszs] += 1
if cross_sample_negatives > 0:
tszs = _buffered_arange(T).unsqueeze(-1).expand(-1, cross_sample_negatives).flatten()
cross_neg_idxs = torch.randint(low=0, high=cross_high - 1, size=(B, cross_sample_negatives * T))
cross_neg_idxs[cross_neg_idxs >= tszs] += 1
if num_negatives > 0:
neg_idxs = neg_idxs + (torch.arange(B).unsqueeze(1) * high)
else:
neg_idxs = cross_neg_idxs
if cross_sample_negatives > 0 and num_negatives > 0:
neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)
negs = input[neg_idxs.view(-1)]
negs = negs.view(B, T, num_negatives + cross_sample_negatives, D).permute(2, 0, 1, 3) # NxBxCxT
return negs, neg_idxs
class NegativeSampler(Module):
    r"""Runs a preprocessing module on the input, then draws negative samples from the result.

    Args:
        preprocessor (nn.Module): Module applied to the input tensor before negative sampling.
        num_negatives (int): Number of negative examples to sample.
        cross_sample_negatives (int): Number of negative examples to cross sample.
    """

    def __init__(self, preprocessor: Module, num_negatives: int, cross_sample_negatives: int):
        super().__init__()
        self.preprocessor = preprocessor
        self.num_negatives = num_negatives
        self.cross_sample_negatives = cross_sample_negatives

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """
        Args:
            input (Tensor): Tensor of dimension `(B, T, D)`.

        Returns:
            (Tensor, Tensor, Optional[Tensor]):
            Tensor
                The preprocessed input, before any sampling is applied.
            Tensor
                The negative samples drawn from the preprocessed input.
            Tensor
                The indices of the negative samples.
        """
        projected = self.preprocessor(input)
        negatives, negative_indices = _sample_negatives(
            projected,
            self.num_negatives,
            self.cross_sample_negatives,
        )
        return projected, negatives, negative_indices
class FeatureEncoder(Module): class FeatureEncoder(Module):
"""Feature Encoder class, consisting of time reduction and linear layer. """Feature Encoder class, consisting of time reduction and linear layer.
Args: Args:
stride (int): number of frames to merge for the output frame stride (int): Number of frames to merge for the output frame.
input_dim (int): input dimension of the tensor input_dim (int): Input dimension of the tensor.
output_dim (int): output dimension of the tensor output_dim (int): Output dimension of the tensor.
""" """
def __init__(self, input_dim: int, output_dim: int, stride: int): def __init__(self, input_dim: int, output_dim: int, stride: int):
...@@ -37,7 +146,7 @@ class FeatureEncoder(Module): ...@@ -37,7 +146,7 @@ class FeatureEncoder(Module):
Returns: Returns:
(Tensor, Optional[Tensor]): (Tensor, Optional[Tensor]):
Tensor: output sequence after undergoing time reduction and linear projection. Tensor: output sequence after undergoing time reduction and linear projection.
Shape ``(B, T // stride, D * stride) Shape ``(B, T // stride, D * stride).
Optional[Tensor]: output lengths of shape ``(B,)`` if lengths parameter is provided, Optional[Tensor]: output lengths of shape ``(B,)`` if lengths parameter is provided,
otherwise `None`. otherwise `None`.
""" """
...@@ -58,15 +167,15 @@ class ConformerEncoder(Module): ...@@ -58,15 +167,15 @@ class ConformerEncoder(Module):
Args: Args:
feature_projection (nn.Module): feature_projection (nn.Module):
Projects feature to encoder dimension Projects feature to encoder dimension.
conformer (nn.ModuleList) conformer (nn.ModuleList)
List of Conformer layers List of Conformer layers.
""" """
def __init__( def __init__(
self, self,
feature_projection: Module, feature_projection: Module,
conformer: nn.ModuleList, conformer: ModuleList,
): ):
super().__init__() super().__init__()
self.feature_projection = feature_projection self.feature_projection = feature_projection
...@@ -111,7 +220,7 @@ class ConformerEncoder(Module): ...@@ -111,7 +220,7 @@ class ConformerEncoder(Module):
) -> Tensor: ) -> Tensor:
""" """
Args: Args:
features (Tensor): Tensor of features of shape ``(B, T, D)`` features (Tensor): Tensor of features of shape ``(B, T, D)``.
lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``. lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.
Returns: Returns:
...@@ -132,17 +241,96 @@ class ConformerEncoder(Module): ...@@ -132,17 +241,96 @@ class ConformerEncoder(Module):
"""Returns the list of outputs from the intermediate layers of conformer block in the encoder. """Returns the list of outputs from the intermediate layers of conformer block in the encoder.
Args: Args:
features (Tensor): Tensor of features of shape ``(B, T, D)`` features (Tensor): Tensor of features of shape ``(B, T, D)``.
lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``. lengths (Tensor or None, optional): Valid length of each input sample. shape: ``(B, )``.
Returns: Returns:
List[Tensor]: List[Tensor]:
Features from requested layers. Each Tensor is of shape: `(batch, time frame, feature dimension)` Features from requested layers. Each Tensor is of shape: `(batch, time frame, feature dimension)`.
""" """
x, masks = self._preprocess(features, lengths) x, masks = self._preprocess(features, lengths)
return self._get_intermediate_outputs(x, mask=masks, num_layers=num_layers) return self._get_intermediate_outputs(x, mask=masks, num_layers=num_layers)
class ConformerWav2Vec2PretrainModel(Module):
    """Conformer Wav2Vec2 pre-train model for training from scratch.

    Note:
        To build the model, please use one of the factory functions,
        :py:func:`conformer_wav2vec2_pretrain_base` or :py:func:`conformer_wav2vec2_pretrain_large`

    Args:
        wav2vec2 (nn.Module):
            Conformer based Wav2Vec2 model, including feature extractor and conformer encoder components.
        mask_generator (nn.Module):
            Mask generator that generates the mask for masked prediction during training.
        negative_sampler (nn.Module):
            Negative sampler to apply after masking.
    """

    def __init__(
        self,
        wav2vec2: Wav2Vec2Model,
        mask_generator: Module,
        negative_sampler: Module,
    ):
        super().__init__()
        self.wav2vec2 = wav2vec2
        self.mask_generator = mask_generator
        self.negative_sampler = negative_sampler

    def forward(
        self,
        features: Tensor,
        audio_lengths: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]:
        """
        Args:
            features (Tensor):
                Tensor of audio features of shape `(batch, frame, dim)`.
            audio_lengths (Tensor or None, optional):
                Tensor of valid length of each valid audio in the batch.
                shape: `(batch, )` (Default: ``None``)

        Returns:
            (Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor):
            Tensor
                The masked sequences of probability distribution of shape `(batch, frame, dim)`.
            Tensor or None
                If ``audio_lengths`` argument was provided, a Tensor of shape `(batch, )` representing
                valid length in time axis is returned.
            Tensor
                The mask indices.
            Tensor
                The targets, prior to negative sampling.
            Tensor
                The negative samples.
            Tensor
                The indices of the negative samples.
        """
        x, lengths = self.wav2vec2.feature_extractor(features, audio_lengths)

        if lengths is not None:
            padding_mask = components._get_padding_mask(x, lengths)
        else:
            padding_mask = None

        # The sub-layers of the feature projection are applied one by one (instead of
        # calling the projection module as a whole) so that masking and negative
        # sampling can run between normalization/dropout and the final projection.
        x = self.wav2vec2.encoder.feature_projection.layer_norm(x)
        x = self.wav2vec2.encoder.feature_projection.dropout(x)

        x, mask_idxs = self.mask_generator(x, padding_mask)

        # The sampler receives the masked features and returns its own
        # preprocessed copy as the targets.
        targets, negs, neg_idxs = self.negative_sampler(x)

        x = self.wav2vec2.encoder.feature_projection.projection(x)

        # The conformer layers are fed with the first two dimensions swapped;
        # swap them back once all layers have run.
        x = x.transpose(0, 1)
        for conformer_layer in self.wav2vec2.encoder.conformer:
            x = conformer_layer(x, padding_mask)
        x = x.transpose(0, 1)

        return x, lengths, mask_idxs, targets, negs, neg_idxs
################################################################################ ################################################################################
def _get_conformer_feature_extractor( def _get_conformer_feature_extractor(
input_dim: int, input_dim: int,
...@@ -152,12 +340,12 @@ def _get_conformer_feature_extractor( ...@@ -152,12 +340,12 @@ def _get_conformer_feature_extractor(
"""Construct Feature Extractor """Construct Feature Extractor
Args: Args:
input_dim (int): Input dimension of features input_dim (int): Input dimension of features.
output_dim (int): Output dimension after feature extraction output_dim (int): Output dimension after feature extraction.
stride (int): Stride used in Time Reduction layer of feature extractor stride (int): Stride used in Time Reduction layer of feature extractor.
Returns: Returns:
FeatureEncoder: The resulting feature extraction FeatureEncoder: The resulting feature extraction.
""" """
return FeatureEncoder(input_dim, output_dim, stride) return FeatureEncoder(input_dim, output_dim, stride)
...@@ -218,7 +406,30 @@ def _get_conformer_encoder( ...@@ -218,7 +406,30 @@ def _get_conformer_encoder(
) )
conformer_layers.append(layer) conformer_layers.append(layer)
return ConformerEncoder(feature_projection, nn.ModuleList(conformer_layers)) return ConformerEncoder(feature_projection, ModuleList(conformer_layers))
def _get_conformer_negativer_sampler(
    input_dim: int,
    output_dim: int,
    num_negatives: int,
    cross_sample_negatives: int,
) -> NegativeSampler:
    """Create a NegativeSampler whose preprocessing step is a single linear layer.

    Args:
        input_dim (int): Dimension of input after feature extraction.
        output_dim (int): Dimension of embedding for use in negative sampling. Same as the
            embedding in the feature projection.
        num_negatives (int): Number of negatives to sample.
        cross_sample_negatives (int): Number of cross sampled negatives.

    Returns:
        NegativeSampler:
            The resulting negative sampler module.
    """
    return NegativeSampler(
        nn.Linear(input_dim, output_dim),
        num_negatives,
        cross_sample_negatives,
    )
def conformer_wav2vec2_model( def conformer_wav2vec2_model(
...@@ -302,7 +513,7 @@ def conformer_wav2vec2_base( ...@@ -302,7 +513,7 @@ def conformer_wav2vec2_base(
Returns: Returns:
Wav2Vec2Model: Wav2Vec2Model:
The resulting wav2vec2 model with a conformer encoder and ``base`` configuration. The resulting wav2vec2 model with a conformer encoder and ``base`` configuration.
""" """
return conformer_wav2vec2_model( return conformer_wav2vec2_model(
extractor_input_dim=extractor_input_dim, extractor_input_dim=extractor_input_dim,
...@@ -318,3 +529,261 @@ def conformer_wav2vec2_base( ...@@ -318,3 +529,261 @@ def conformer_wav2vec2_base(
encoder_convolution_first=True, encoder_convolution_first=True,
encoder_use_group_norm=True, encoder_use_group_norm=True,
) )
def conformer_wav2vec2_pretrain_model(
    extractor_input_dim: int,
    extractor_output_dim: int,
    extractor_stride: int,
    encoder_embed_dim: int,
    encoder_projection_dropout: float,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_ff_interm_features: int,
    encoder_depthwise_conv_kernel_size: Union[int, List[int]],
    encoder_dropout: float,
    encoder_convolution_first: bool,
    encoder_use_group_norm: bool,
    mask_prob: float,
    mask_selection: str,
    mask_other: float,
    mask_length: int,
    no_mask_overlap: bool,
    mask_min_space: int,
    mask_channel_prob: float,
    mask_channel_selection: str,
    mask_channel_other: float,
    mask_channel_length: int,
    no_mask_channel_overlap: bool,
    mask_channel_min_space: int,
    num_negatives: int,
    cross_sample_negatives: int,
) -> ConformerWav2Vec2PretrainModel:
    """Build a custom Conformer Wav2Vec2 Model for pre-training

    Args:
        extractor_input_dim (int): Input dimension of the features.
        extractor_output_dim (int): Output dimension after feature extraction.
        extractor_stride (int):
            Stride used in time reduction layer of feature extraction.
        encoder_embed_dim (int):
            The dimension of the embedding in the feature projection.
        encoder_projection_dropout (float):
            The dropout probability applied after the input feature is projected to
            ``embed_dim``.
        encoder_num_layers (int):
            Number of Conformer layers in the encoder.
        encoder_num_heads (int):
            Number of heads in each Conformer layer.
        encoder_ff_interm_features (int):
            Hidden layer dimension of the feedforward network in each Conformer layer.
        encoder_depthwise_conv_kernel_size (int or List[int]):
            List of kernel sizes corresponding to each of the Conformer layers.
            If int is provided, all layers will have the same kernel size.
        encoder_dropout (float):
            Dropout probability in each Conformer layer.
        encoder_convolution_first (bool):
            Whether to apply the convolution module ahead of the attention module
            in each Conformer layer.
        encoder_use_group_norm (bool):
            Whether to use ``GroupNorm`` rather than ``BatchNorm1d`` in the convolution
            module in each Conformer layer.
        mask_prob (float):
            Probability for each token to be chosen as start of the span to be masked.
        mask_selection (str):
            How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_other (float):
            Secondary mask argument (used for more complex distributions).
        mask_length (int):
            The lengths of the mask.
        no_mask_overlap (bool):
            Mask overlap flag, forwarded unchanged to ``components.MaskGenerator``.
            The name suggests ``True`` forbids overlapping masks; confirm against
            ``MaskGenerator``.
        mask_min_space (int):
            Minimum space between spans (if no overlap is enabled).
        mask_channel_prob (float):
            The probability of replacing a feature with 0.
        mask_channel_selection (str):
            How to choose the mask length for channel masking.
            Options: [``static``, ``uniform``, ``normal``, ``poisson``].
        mask_channel_other (float):
            Secondary mask argument for channel masking (used for more complex distributions).
        mask_channel_length (int):
            The lengths of the mask for channel masking.
        no_mask_channel_overlap (bool):
            Channel mask overlap flag, forwarded unchanged to ``components.MaskGenerator``.
            The name suggests ``True`` forbids overlapping channel masks; confirm against
            ``MaskGenerator``.
        mask_channel_min_space (int):
            Minimum space between spans for channel masking (if no overlap is enabled).
        num_negatives (int):
            Number of negatives to sample.
        cross_sample_negatives (int):
            Number of cross sampled negatives.

    Returns:
        ConformerWav2Vec2PretrainModel:
            The resulting model.
    """
    wav2vec2 = conformer_wav2vec2_model(
        extractor_input_dim,
        extractor_output_dim,
        extractor_stride,
        encoder_embed_dim,
        encoder_projection_dropout,
        encoder_num_layers,
        encoder_num_heads,
        encoder_ff_interm_features,
        encoder_depthwise_conv_kernel_size,
        encoder_dropout,
        encoder_convolution_first,
        encoder_use_group_norm,
    )

    # Masking is applied to the feature extractor output, hence its dimension here.
    mask_generator = components.MaskGenerator(
        extractor_output_dim,
        mask_prob,
        mask_selection,
        mask_other,
        mask_length,
        no_mask_overlap,
        mask_min_space,
        mask_channel_prob,
        mask_channel_selection,
        mask_channel_other,
        mask_channel_length,
        no_mask_channel_overlap,
        mask_channel_min_space,
    )

    # The sampler maps extractor features to the encoder embedding dimension.
    negative_sampler = _get_conformer_negativer_sampler(
        extractor_output_dim,
        encoder_embed_dim,
        num_negatives,
        cross_sample_negatives,
    )

    return ConformerWav2Vec2PretrainModel(
        wav2vec2=wav2vec2,
        mask_generator=mask_generator,
        negative_sampler=negative_sampler,
    )
def conformer_wav2vec2_pretrain_base(
    extractor_input_dim: int = 64,
    extractor_output_dim: int = 256,
    encoder_projection_dropout: float = 0.0,
    mask_prob: float = 0.3,
    mask_length: int = 3,
    num_negatives: int = 100,
    cross_sample_negatives: int = 0,
) -> ConformerWav2Vec2PretrainModel:
    """Build Conformer Wav2Vec2 Model for pre-training with "small" architecture from
    *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`conformerssl`

    Args:
        extractor_input_dim (int, optional): Input dimension of the features. (Default: 64)
        extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256)
        encoder_projection_dropout (float, optional):
            The dropout probability applied after the input feature is projected to
            ``embed_dim``. (Default: 0.0)
        mask_prob (float, optional):
            Probability for each token to be chosen as start of the span to be masked. (Default: 0.3)
        mask_length (int, optional):
            The lengths of the mask. (Default: 3)
        num_negatives (int, optional):
            Number of sampled negatives. (Default: 100)
        cross_sample_negatives (int, optional):
            Number of cross sampled negatives. (Default: 0)

    Returns:
        ConformerWav2Vec2PretrainModel:
            The resulting model.
    """
    return conformer_wav2vec2_pretrain_model(
        extractor_input_dim=extractor_input_dim,
        extractor_output_dim=extractor_output_dim,
        extractor_stride=4,
        encoder_embed_dim=256,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_num_layers=12,
        encoder_num_heads=8,
        encoder_ff_interm_features=1024,
        # One kernel size per encoder layer: 31 for the first layer, 15 for the other 11.
        encoder_depthwise_conv_kernel_size=[31] + [15] * 11,
        encoder_dropout=0.1,
        encoder_convolution_first=True,
        encoder_use_group_norm=True,
        mask_prob=mask_prob,
        mask_selection="static",
        mask_other=0.0,
        mask_length=mask_length,
        no_mask_overlap=False,
        mask_min_space=0,
        mask_channel_prob=0,
        mask_channel_selection="static",
        mask_channel_other=0,
        mask_channel_length=10,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        num_negatives=num_negatives,
        cross_sample_negatives=cross_sample_negatives,
    )
def conformer_wav2vec2_pretrain_large(
    extractor_input_dim: int = 64,
    extractor_output_dim: int = 256,
    encoder_projection_dropout: float = 0.0,
    mask_prob: float = 0.3,
    mask_length: int = 3,
    num_negatives: int = 100,
    cross_sample_negatives: int = 0,
) -> ConformerWav2Vec2PretrainModel:
    """Build Conformer Wav2Vec2 Model for pre-training with "large" architecture from
    *Conformer-Based Self-Supervised Learning for Non-Speech Audio Tasks* :cite:`conformerssl`

    Args:
        extractor_input_dim (int, optional): Input dimension of the features. (Default: 64)
        extractor_output_dim (int, optional): Output dimension after feature extraction. (Default: 256)
        encoder_projection_dropout (float, optional):
            The dropout probability applied after the input feature is projected to
            ``embed_dim``. (Default: 0.0)
        mask_prob (float, optional):
            Probability for each token to be chosen as start of the span to be masked. (Default: 0.3)
        mask_length (int, optional):
            The lengths of the mask. (Default: 3)
        num_negatives (int, optional):
            Number of sampled negatives. (Default: 100)
        cross_sample_negatives (int, optional):
            Number of cross sampled negatives. (Default: 0)

    Returns:
        ConformerWav2Vec2PretrainModel:
            The resulting model.
    """
    return conformer_wav2vec2_pretrain_model(
        extractor_input_dim=extractor_input_dim,
        extractor_output_dim=extractor_output_dim,
        extractor_stride=4,
        encoder_embed_dim=768,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_ff_interm_features=1024,
        # One kernel size per encoder layer: 31 for the first layer, 15 for the other 11.
        encoder_depthwise_conv_kernel_size=[31] + [15] * 11,
        encoder_dropout=0.1,
        encoder_convolution_first=True,
        encoder_use_group_norm=True,
        mask_prob=mask_prob,
        mask_selection="static",
        mask_other=0.0,
        mask_length=mask_length,
        no_mask_overlap=False,
        mask_min_space=0,
        mask_channel_prob=0,
        mask_channel_selection="static",
        mask_channel_other=0,
        mask_channel_length=10,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        num_negatives=num_negatives,
        cross_sample_negatives=cross_sample_negatives,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment