Commit 37a2555f authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add HuBERT pretrain model to enable training from scratch (#2064)

Summary:
- Add three factory functions, `hubert_pretrain_base`, `hubert_pretrain_large`, and `hubert_pretrain_xlarge`, to enable training the HuBERT model from scratch.
- Add a `num_classes` argument to the `hubert_pretrain_base` factory function because the base model is trained in two iterations: `num_cluster` is 100 in the first iteration and 500 in the second.
- The model takes `waveforms`, `labels`, and `lengths` as inputs.
- The model outputs `logit_m` and `logit_u`, the logits of the masked and unmasked frames computed from the last transformer layer's embeddings.
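
A minimal usage sketch (illustrative only; it assumes `labels` are frame-level k-means cluster IDs, one per feature frame produced by the convolutional feature extractor, i.e. one per ~20 ms of 16 kHz audio):

```python
import torch
import torchaudio

# First-iteration base model: 100 clusters (pass num_classes=500 for the second iteration).
model = torchaudio.models.hubert_pretrain_base(num_classes=100)

waveforms = torch.randn(2, 16000)        # [batch, time], 16 kHz audio
lengths = torch.tensor([16000, 12800])   # valid samples per utterance
labels = torch.randint(0, 100, (2, 49))  # frame-level cluster IDs (49 frames for 16000 samples)

logit_m, logit_u = model(waveforms, labels, lengths)  # logits over masked / unmasked frames
```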

Pull Request resolved: https://github.com/pytorch/audio/pull/2064

Reviewed By: hwangjeff, mthrok

Differential Revision: D33338587

Pulled By: nateanl

fbshipit-source-id: 534bc17c576c5f344043d8ba098204b8da6e630a
parent 7bf04d1e
@@ -59,6 +59,13 @@ Wav2Vec2Model
.. automethod:: forward
HuBERTPretrainModel
^^^^^^^^^^^^^^^^^^^
.. autoclass:: HuBERTPretrainModel
.. automethod:: forward
Factory Functions
-----------------
@@ -98,6 +105,26 @@ hubert_xlarge
.. autofunction:: hubert_xlarge
hubert_pretrain_model
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_model
hubert_pretrain_base
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_base
hubert_pretrain_large
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_large
hubert_pretrain_xlarge
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_xlarge
Utility Functions
-----------------
@@ -4,6 +4,7 @@ from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
Wav2Vec2Model,
HuBERTPretrainModel,
wav2vec2_model,
wav2vec2_base,
wav2vec2_large,
@@ -11,6 +12,10 @@ from .wav2vec2 import (
hubert_base,
hubert_large,
hubert_xlarge,
hubert_pretrain_model,
hubert_pretrain_base,
hubert_pretrain_large,
hubert_pretrain_xlarge,
)
from .wavernn import WaveRNN
@@ -20,6 +25,7 @@ __all__ = [
"ConvTasNet",
"DeepSpeech",
"Wav2Vec2Model",
"HuBERTPretrainModel",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
@@ -27,5 +33,9 @@ __all__ = [
"hubert_base",
"hubert_large",
"hubert_xlarge",
"hubert_pretrain_model",
"hubert_pretrain_base",
"hubert_pretrain_large",
"hubert_pretrain_xlarge",
"Tacotron2",
]
from . import utils
from .model import (
Wav2Vec2Model,
HuBERTPretrainModel,
wav2vec2_model,
wav2vec2_base,
wav2vec2_large,
@@ -8,10 +9,15 @@ from .model import (
hubert_base,
hubert_large,
hubert_xlarge,
hubert_pretrain_model,
hubert_pretrain_base,
hubert_pretrain_large,
hubert_pretrain_xlarge,
)
__all__ = [
"Wav2Vec2Model",
"HuBERTPretrainModel",
"wav2vec2_model",
"wav2vec2_base",
"wav2vec2_large",
@@ -19,5 +25,9 @@ __all__ = [
"hubert_base",
"hubert_large",
"hubert_xlarge",
"hubert_pretrain_model",
"hubert_pretrain_base",
"hubert_pretrain_large",
"hubert_pretrain_xlarge",
"utils",
]
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, List
import torch
from torch import Tensor, nn
from torch.nn import Module
from torch.nn import Module, Parameter
_LG = logging.getLogger(__name__)
@@ -713,3 +713,327 @@ def _get_encoder(
layer_drop=layer_drop,
)
return Encoder(feature_projection, transformer)
def _compute_mask_indices(
shape: Tuple[int, int],
padding_mask: Optional[Tensor],
mask_prob: float,
mask_length: int,
mask_type: str = "static",
mask_other: float = 0.0,
min_masks: int = 0,
no_overlap: bool = False,
min_space: int = 0,
) -> Tensor:
"""Computes random mask spans for a given shape.
Args:
shape (int, int): The shape for which to compute masks.
The first element is batch size and second is the number of frames.
padding_mask (Tensor or None): The padding mask of the same dimension as shape,
which will prevent masking padded elements.
mask_prob (float): Probability for each token to be chosen as the start of the span to be masked.
This will be multiplied by the number of timesteps divided by the length of the mask span to mask
approximately this percentage of all elements. However, due to overlaps, the actual number
will be smaller (unless no_overlap is True).
mask_length (int): The lengths of the mask.
mask_type (str): How to compute mask lengths. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
``static``: Fixed size
``uniform``: Sample from uniform distribution [mask_other, mask_length*2]
``normal``: Sample from normal distribution with mean ``mask_length`` and stdev ``mask_other``.
``poisson``: Sample from Poisson distribution with lambda = ``mask_length``.
mask_other (float): Secondary mask argument (used for more complex distributions).
min_masks (int): Minimum number of masked spans.
no_overlap (bool): If True, switch to an alternative recursive algorithm
that prevents spans from overlapping.
min_space (int): How many frames to keep unmasked between spans (only used if no_overlap is True).
Returns:
(Tensor): The mask indices of dimension `[batch, frame]`.
"""
batch_size, frame = shape
mask = torch.full((batch_size, frame), False)
# add a random number for probabilistic rounding
all_num_mask = int(mask_prob * frame / float(mask_length) + torch.rand(1))
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(batch_size):
if padding_mask is not None:
sz = frame - padding_mask[i].long().sum().item()
# add a random number for probabilistic rounding
num_mask = int(mask_prob * sz / float(mask_length) + torch.rand(1))
num_mask = max(min_masks, num_mask)
else:
sz = frame
num_mask = all_num_mask
if mask_type == "static":
lengths = torch.full((num_mask,), mask_length)
elif mask_type == "uniform":
lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
elif mask_type == "normal":
lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
elif mask_type == "poisson":
# torch.poisson draws one sample per entry of the rate tensor
lengths = torch.poisson(torch.full((num_mask,), float(mask_length)))
lengths = torch.round(lengths).int()
else:
raise Exception(f"unknown mask selection: {mask_type}")
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
span_start = torch.randint(s, e - length, size=(1,))
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
parts = [(0, sz)]
min_length = min(lengths)
for length in sorted(lengths, reverse=True):
lens = torch.tensor([e - s for s, e in parts], dtype=torch.int)
lens[lens < length + min_space] = 0
l_sum = lens.sum()
if l_sum == 0:
break
probs = lens / l_sum
c = torch.distributions.categorical.Categorical(probs).sample()
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = torch.tensor(mask_idc)
else:
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = torch.multinomial(torch.ones((sz - min_len,)), num_samples=num_mask, replacement=False)
mask_idc = torch.tensor(
[mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
)
mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = torch.index_select(
mask_idc,
0,
torch.multinomial(
torch.ones((mask_idc.shape[0],)),
num_samples=min_len,
replacement=False,
),
)
mask[i, mask_idc] = True
return mask
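# Usage sketch (illustrative only, not part of the torchaudio API): roughly
# ``mask_prob * frame / mask_length`` spans of length ``mask_length`` are drawn
# per utterance, so overlaps make the masked fraction at most ``mask_prob``.
def _example_compute_mask_indices() -> Tensor:  # hypothetical helper, for illustration
    mask = _compute_mask_indices(
        shape=(2, 100),     # batch of 2 utterances, 100 feature frames each
        padding_mask=None,  # all frames are valid
        mask_prob=0.8,      # values used by the "base" pre-training recipe
        mask_length=10,
        mask_type="static",
        min_masks=2,
    )
    return mask  # Bool Tensor of shape [2, 100]; True marks masked frames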
def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor:
"""Generate the padding mask given the padded input and the lengths Tensors.
Args:
input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`.
lengths (Tensor): The lengths Tensor of dimension `[batch,]`.
Returns:
(Tensor): The padding mask.
"""
batch_size, max_len, _ = input.shape
mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
return mask
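# Usage sketch (illustrative only): for features of shape [2, 5, 512] with valid
# lengths [5, 3], padded frames are marked with True.
def _example_padding_mask() -> Tensor:  # hypothetical helper, for illustration
    features = torch.zeros(2, 5, 512)
    lengths = torch.tensor([5, 3])
    return _get_padding_mask(features, lengths)
    # -> tensor([[False, False, False, False, False],
    #            [False, False, False,  True,  True]])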
class MaskGenerator(Module):
"""Generate the masks for masked prediction.
Args:
encoder_embed_dim (int): The dimension of the transformer embedding output.
mask_prob (float): Probability for each token to be chosen as the start of the span to be masked.
This will be multiplied by the number of timesteps divided by the length of the mask span to mask
approximately this percentage of all elements. However, due to overlaps, the actual number
will be smaller (unless no_overlap is True).
mask_selection (str): How to choose the mask length.
Options: [``static``, ``uniform``, ``normal``, ``poisson``].
mask_other (float): Secondary mask argument (used for more complex distributions).
mask_length (int): The lengths of the mask.
no_mask_overlap (bool): Whether to allow masks to overlap.
mask_min_space (int): Minimum space between spans (if no overlap is enabled).
mask_channel_prob (float): The probability of replacing a feature with 0.
mask_channel_selection (str): How to choose the mask length for channel masking.
Options: [``static``, ``uniform``, ``normal``, ``poisson``].
mask_channel_other (float): Secondary mask argument for channel masking (used for more complex distributions).
mask_channel_length (int): The lengths of the mask for channel masking.
no_mask_channel_overlap (bool): Whether to allow channel masks to overlap.
mask_channel_min_space (int): Minimum space between spans for channel masking (if no overlap is enabled).
"""
def __init__(
self,
encoder_embed_dim: int,
mask_prob: float,
mask_selection: str,
mask_other: float,
mask_length: int,
no_mask_overlap: bool,
mask_min_space: int,
mask_channel_prob: float,
mask_channel_selection: str,
mask_channel_other: float,
mask_channel_length: int,
no_mask_channel_overlap: bool,
mask_channel_min_space: int,
):
super().__init__()
self.mask_prob = mask_prob
self.mask_selection = mask_selection
self.mask_other = mask_other
self.mask_length = mask_length
self.no_mask_overlap = no_mask_overlap
self.mask_min_space = mask_min_space
self.mask_channel_prob = mask_channel_prob
self.mask_channel_selection = mask_channel_selection
self.mask_channel_other = mask_channel_other
self.mask_channel_length = mask_channel_length
self.no_mask_channel_overlap = no_mask_channel_overlap
self.mask_channel_min_space = mask_channel_min_space
self.mask_embedding = Parameter(torch.FloatTensor(encoder_embed_dim))
torch.nn.init.uniform_(self.mask_embedding)
def forward(self, x: Tensor, padding_mask: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x (Tensor): The encoded representations after feature extraction module.
padding_mask (Tensor or None): The padding mask of dimension `[batch, frame]`,
which will prevent masking padded elements.
Returns:
Tensor: The feature representations after masking.
Tensor: The generated mask indices.
"""
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = _compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = mask_indices.to(x.device)
x[mask_indices] = self.mask_embedding
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = _compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
x[mask_channel_indices] = 0
return x, mask_indices
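# Usage sketch (illustrative only): MaskGenerator with the masking configuration
# used by the pre-training factory functions, applied to random encoder features.
def _example_mask_generator() -> Tuple[Tensor, Tensor]:  # hypothetical helper, for illustration
    generator = MaskGenerator(
        encoder_embed_dim=768,
        mask_prob=0.8,
        mask_selection="static",
        mask_other=0.0,
        mask_length=10,
        no_mask_overlap=False,
        mask_min_space=1,
        mask_channel_prob=0.0,
        mask_channel_selection="static",
        mask_channel_other=0.0,
        mask_channel_length=10,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
    )
    features = torch.randn(2, 100, 768)       # [batch, frame, encoder_embed_dim]
    masked, mask = generator(features, None)  # masked features and Bool mask indices
    return masked, mask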
def _compute_logits(
proj_x: Tensor,
target: Tensor,
label_embeddings: Parameter,
) -> Tensor:
"""Compute the logits of the embeddings.
Args:
proj_x (Tensor): The projected (masked or unmasked) representations of dimension `[frame, final_dim]`.
target (Tensor): The target Tensor of dimension `[frame,]`.
label_embeddings (Parameter): The trainable embeddings of target of dimension `[num_class, final_dim]`.
Returns:
(Tensor): The logits of the inputs.
"""
logit_temp = 0.1
pos = torch.index_select(label_embeddings, 0, target.long())
negs = label_embeddings.unsqueeze(1).expand(-1, proj_x.size(0), -1)
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity(proj_x.float(), targets.float(), dim=-1).type_as(proj_x)
logits /= logit_temp
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
return logits
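# Usage sketch (illustrative only): cosine-similarity logits for 4 frames against
# 100 label embeddings. The positive embedding is prepended before the negatives,
# so each output row has ``num_class + 1`` entries, with the positive at index 0.
def _example_compute_logits() -> Tensor:  # hypothetical helper, for illustration
    label_embeddings = Parameter(torch.randn(100, 256))  # [num_class, final_dim]
    proj_x = torch.randn(4, 256)                          # projected frames
    target = torch.randint(0, 100, (4,))                  # cluster ID per frame
    return _compute_logits(proj_x, target, label_embeddings)  # shape [4, 101]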
class LogitGenerator(Module):
"""Generate the logits of masked and unmasked inputs.
Args:
encoder_embed_dim (int): The dimension of the transformer embedding output.
num_classes (int): The number of classes in the labels.
final_dim (int): Project final representations and targets to `final_dim`.
skip_masked (bool): If True, skip computing losses over masked frames.
skip_nomask (bool): If True, skip computing losses over unmasked frames.
"""
def __init__(
self,
encoder_embed_dim: int,
num_classes: int,
final_dim: int,
skip_masked: bool,
skip_nomask: bool,
):
super().__init__()
self.label_embeddings = Parameter(torch.FloatTensor(num_classes, final_dim))
torch.nn.init.uniform_(self.label_embeddings)
self.final_proj = torch.nn.Linear(encoder_embed_dim, final_dim)
self.skip_masked = skip_masked
self.skip_nomask = skip_nomask
def forward(self, x: Tensor, label: Tensor, mask_m: Tensor, mask_u: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
x (Tensor): The feature representation of the last transformer layer.
label (Tensor): The label Tensor of dimension `[batch, frame]`.
mask_m (Tensor): The masked indices of dimension `[batch, frame]`.
mask_u (Tensor): The unmasked indices of dimension `[batch, frame]`.
Returns:
Tensor: The logits of masked frames. Tensor of dimension `[masked_frame, num_classes + 1]`.
Tensor: The logits of unmasked frames. Tensor of dimension `[unmasked_frame, num_classes + 1]`.
"""
proj_x = self.final_proj(x)
if self.skip_masked:
logit_m = None
else:
proj_x_m = proj_x[mask_m]
label_m = label[mask_m]
logit_m = _compute_logits(proj_x_m, label_m, self.label_embeddings)
if self.skip_nomask:
logit_u = None
else:
proj_x_u = proj_x[mask_u]
label_u = label[mask_u]
logit_u = _compute_logits(proj_x_u, label_u, self.label_embeddings)
return logit_m, logit_u
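# Usage sketch (illustrative only): computing masked/unmasked logits from the
# output of the last transformer layer with 100 label classes.
def _example_logit_generator() -> Tuple[Tensor, Tensor]:  # hypothetical helper, for illustration
    generator = LogitGenerator(
        encoder_embed_dim=768,
        num_classes=100,
        final_dim=256,
        skip_masked=False,
        skip_nomask=False,
    )
    x = torch.randn(2, 100, 768)             # last transformer layer output
    label = torch.randint(0, 100, (2, 100))  # frame-level cluster IDs
    mask_m = torch.zeros(2, 100, dtype=torch.bool)
    mask_m[:, :30] = True                    # pretend the first 30 frames are masked
    mask_u = ~mask_m
    logit_m, logit_u = generator(x, label, mask_m, mask_u)
    return logit_m, logit_u                  # shapes [60, 101] and [140, 101]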
@@ -117,6 +117,89 @@ class Wav2Vec2Model(Module):
return x, lengths
class HuBERTPretrainModel(Module):
"""HuBERT pre-train model for training from scratch.
Note:
To build the model, please use one of the factory functions in
`[hubert_pretrain_base, hubert_pretrain_large, hubert_pretrain_xlarge]`.
Args:
wav2vec2 (Wav2Vec2Model):
Wav2Vec2 model, composed of the feature extractor that extracts feature vectors
from raw audio and the transformer encoder that converts them into a sequence
of feature representations.
mask_generator (torch.nn.Module):
Mask generator that generates the mask for masked prediction during training.
logit_generator (torch.nn.Module):
Logit generator that predicts the logits of the masked and unmasked inputs.
"""
def __init__(
self,
wav2vec2: Wav2Vec2Model,
mask_generator: Module,
logit_generator: Module,
):
super().__init__()
self.wav2vec2 = wav2vec2
self.mask_generator = mask_generator
self.logit_generator = logit_generator
def forward(
self,
waveforms: Tensor,
labels: Tensor,
audio_lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Compute the sequence of probability distribution over labels.
Args:
waveforms (Tensor): Audio tensor of dimension `[batch, frames]`.
labels (Tensor): Label for pre-training. A Tensor of dimension `[batch, frames]`.
audio_lengths (Tensor or None, optional):
Indicates the valid length of each audio in the batch.
Shape: `[batch, ]`.
When the ``waveforms`` contains audios with different durations,
by providing the ``audio_lengths`` argument, the model will compute
the corresponding valid output lengths and apply proper mask in
transformer attention layer.
If ``None``, it is assumed that all the audio in ``waveforms``
have valid length. Default: ``None``.
Returns:
(Tensor, Tensor):
Tensor
The masked sequences of probability distribution (in logit).
Shape: `(masked_frames, num_classes + 1)`.
Tensor
The unmasked sequence of probability distribution (in logit).
Shape: `(unmasked_frames, num_classes + 1)`.
"""
x, lengths = self.wav2vec2.feature_extractor(waveforms, audio_lengths)
if lengths is not None:
padding_mask = components._get_padding_mask(x, lengths)
else:
padding_mask = None
x, attention_mask = self.wav2vec2.encoder._preprocess(x, lengths)
x, mask = self.mask_generator(x, padding_mask)
x = self.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)
if padding_mask is not None:
mask_m = torch.logical_and(~padding_mask, mask)
mask_u = torch.logical_and(~padding_mask, ~mask_m)
else:
mask_m = mask
mask_u = ~mask_m
logit_m, logit_u = self.logit_generator(x, labels, mask_m, mask_u)
return logit_m, logit_u
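# Usage sketch (illustrative only, not part of this commit): one possible training
# criterion over the returned logits. Following the fairseq HuBERT recipe, the
# positive (true cluster) sits at index 0 of every logit row, so the cross-entropy
# target is a zero vector; how the two terms are weighted is a training choice.
def _example_pretrain_loss(logit_m: Tensor, logit_u: Tensor) -> Tensor:
    target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
    target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
    loss_m = torch.nn.functional.cross_entropy(logit_m, target_m)
    loss_u = torch.nn.functional.cross_entropy(logit_u, target_u)
    return loss_m + loss_u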
def wav2vec2_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
@@ -590,3 +673,476 @@ def hubert_xlarge(
encoder_layer_drop=encoder_layer_drop,
aux_num_out=aux_num_out,
)
def hubert_pretrain_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
mask_prob: float,
mask_selection: str,
mask_other: float,
mask_length: int,
no_mask_overlap: bool,
mask_min_space: int,
mask_channel_prob: float,
mask_channel_selection: str,
mask_channel_other: float,
mask_channel_length: int,
no_mask_channel_overlap: bool,
mask_channel_min_space: int,
skip_masked: bool,
skip_nomask: bool,
num_classes: int,
final_dim: int,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, mask_prob: float, mask_selection: str, mask_other: float, mask_length: int, no_mask_overlap: bool, mask_min_space: int, mask_channel_prob: float, mask_channel_selection: str, mask_channel_other: float, mask_channel_length: int, no_mask_channel_overlap: bool, mask_channel_min_space: int, skip_masked: bool, skip_nomask: bool, num_classes: int, final_dim: int) -> torchaudio.models.HuBERTPretrainModel
Build a custom HuBERTPretrainModel for training from scratch
Note:
The "feature extractor" below corresponds to
`ConvFeatureExtractionModel <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L736>`__
in the original ``fairseq`` implementation.
This is referred to as "(convolutional) feature encoder" in the *wav2vec 2.0*
[:footcite:`baevski2020wav2vec`] paper.
The "encoder" below corresponds to `TransformerEncoder <https://github.com/pytorch/fairseq/blob/dd3bd3c0497ae9a7ae7364404a6b0a4c501780b3/fairseq/models/wav2vec/wav2vec2.py#L817>`__,
and this is referred to as "Transformer" in the paper.
Args:
extractor_mode (str): Operation mode of feature extractor.
Valid values are ``"group_norm"`` or ``"layer_norm"``.
If ``"group_norm"``, then a single normalization is applied
in the first convolution block. Otherwise, all the convolution
blocks will have layer normalization.
This option corresponds to ``extractor_mode`` from ``fairseq``.
extractor_conv_layer_config (list of integer tuples or None):
Configuration of convolution layers in feature extractor.
List of convolution configuration,
i.e. ``[(output_channel, kernel_size, stride), ...]``
If ``None`` is provided, then the following default value is used.
.. code-block:: python
[
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
]
This option corresponds to ``conv_feature_layers`` from ``fairseq``.
extractor_conv_bias (bool):
Whether to include bias term to each convolution operation.
This option corresponds to ``conv_bias`` from ``fairseq``.
encoder_embed_dim (int):
The dimension of embedding in encoder.
This option corresponds to ``encoder_embed_dim`` from ``fairseq``.
encoder_projection_dropout (float):
The dropout probability applied after the input feature is projected
to ``encoder_embed_dim``.
This option corresponds to ``dropout_input`` from ``fairseq``.
encoder_pos_conv_kernel (int):
The kernel size of convolutional positional embeddings.
This option corresponds to ``conv_pos`` from ``fairseq``.
encoder_pos_conv_groups (int):
The number of groups of convolutional positional embeddings.
This option corresponds to ``conv_pos_groups`` from ``fairseq``.
encoder_num_layers (int):
The number of self attention layers in transformer block.
This option corresponds to ``encoder_layers`` from ``fairseq``.
encoder_num_heads (int):
The number of heads in self attention layers.
This option corresponds to ``encoder_attention_heads`` from ``fairseq``.
encoder_attention_dropout (float):
The dropout probability applied after softmax in self-attention layer.
This option corresponds to ``attention_dropout`` from ``fairseq``.
encoder_ff_interm_features (int):
The dimension of hidden features in feed forward layer.
This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``.
encoder_ff_interm_dropout (float):
The dropout probability applied in feedforward layer.
This option corresponds to ``activation_dropout`` from ``fairseq``.
encoder_dropout (float):
The dropout probability applied at the end of feed forward layer.
This option corresponds to ``dropout`` from ``fairseq``.
encoder_layer_norm_first (bool):
Control the order of layer norm in transformer layer and each encoder layer.
If True, in transformer layer, layer norm is applied before features are fed
to encoder layers. In encoder layer, two layer norms are applied before and after
self attention.
If False, in transformer layer, layer norm is applied after features are fed
to encoder layers. In encoder layer, two layer norms are applied after self
attention, before and after feed forward.
This option corresponds to ``layer_norm_first`` from ``fairseq``.
encoder_layer_drop (float):
Probability to drop each encoder layer during training.
This option corresponds to ``layerdrop`` from ``fairseq``.
mask_prob (float):
Probability for each token to be chosen as the start of the span to be masked. This will be multiplied by
the number of timesteps divided by the length of the mask span to mask approximately this percentage of all elements.
However, due to overlaps, the actual number will be smaller (unless no_overlap is True).
This option corresponds to ``mask_prob`` from ``fairseq``.
mask_selection (str):
How to choose the mask length. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
This option corresponds to ``mask_selection`` from ``fairseq``.
mask_other (float):
Secondary mask argument (used for more complex distributions).
This option corresponds to ``mask_other`` from ``fairseq``.
mask_length (int):
The lengths of the mask.
This option corresponds to ``mask_length`` from ``fairseq``.
no_mask_overlap (bool):
Whether to allow masks to overlap.
This option corresponds to ``no_mask_overlap`` from ``fairseq``.
mask_min_space (int):
Minimum space between spans (if no overlap is enabled).
This option corresponds to ``mask_min_space`` from ``fairseq``.
mask_channel_prob (float):
The probability of replacing a feature with 0.
This option corresponds to ``mask_channel_prob`` from ``fairseq``.
mask_channel_selection (str):
How to choose the mask length for channel masking. Options: [``static``, ``uniform``, ``normal``, ``poisson``].
This option corresponds to ``mask_channel_selection`` from ``fairseq``.
mask_channel_other (float):
Secondary mask argument for channel masking (used for more complex distributions).
This option corresponds to ``mask_channel_other`` from ``fairseq``.
mask_channel_length (int):
The lengths of the mask for channel masking.
This option corresponds to ``mask_channel_length`` from ``fairseq``.
no_mask_channel_overlap (bool):
Whether to allow channel masks to overlap.
This option corresponds to ``no_mask_channel_overlap`` from ``fairseq``.
mask_channel_min_space (int):
Minimum space between spans for channel masking (if no overlap is enabled).
This option corresponds to ``mask_channel_min_space`` from ``fairseq``.
skip_masked (bool):
If True, skip computing losses over masked frames.
This option corresponds to ``skip_masked`` from ``fairseq``.
skip_nomask (bool):
If True, skip computing losses over unmasked frames.
This option corresponds to ``skip_nomask`` from ``fairseq``.
num_classes (int):
The number of classes in the labels.
final_dim (int):
Project final representations and targets to `final_dim`.
This option corresponds to ``final_dim`` from ``fairseq``.
Returns:
HuBERTPretrainModel:
The resulting model.
""" # noqa: E501
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias
)
encoder = components._get_encoder(
in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim,
dropout_input=encoder_projection_dropout,
pos_conv_kernel=encoder_pos_conv_kernel,
pos_conv_groups=encoder_pos_conv_groups,
num_layers=encoder_num_layers,
num_heads=encoder_num_heads,
attention_dropout=encoder_attention_dropout,
ff_interm_features=encoder_ff_interm_features,
ff_interm_dropout=encoder_ff_interm_dropout,
dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
)
wav2vec2 = Wav2Vec2Model(feature_extractor, encoder)
mask_generator = components.MaskGenerator(
encoder_embed_dim,
mask_prob,
mask_selection,
mask_other,
mask_length,
no_mask_overlap,
mask_min_space,
mask_channel_prob,
mask_channel_selection,
mask_channel_other,
mask_channel_length,
no_mask_channel_overlap,
mask_channel_min_space,
)
logit_generator = components.LogitGenerator(
encoder_embed_dim,
num_classes,
final_dim,
skip_masked,
skip_nomask,
)
return HuBERTPretrainModel(wav2vec2=wav2vec2, mask_generator=mask_generator, logit_generator=logit_generator)
def hubert_pretrain_base(
encoder_projection_dropout: float = 0.1,
encoder_attention_dropout: float = 0.1,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.1,
encoder_layer_drop: float = 0.05,
num_classes: int = 100,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, num_classes: int = 100) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_attention_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_ff_interm_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_layer_drop (float):
See :py:func:`hubert_pretrain_model`.
num_classes (int, optional):
See :py:func:`hubert_pretrain_model`.
Returns:
HuBERTPretrainModel:
The resulting model.
""" # noqa: E501
return hubert_pretrain_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=False,
encoder_layer_drop=encoder_layer_drop,
mask_prob=0.80,
mask_selection="static",
mask_other=0.0,
mask_length=10,
no_mask_overlap=False,
mask_min_space=1,
mask_channel_prob=0.0,
mask_channel_selection="static",
mask_channel_other=0.0,
mask_channel_length=10,
no_mask_channel_overlap=False,
mask_channel_min_space=1,
skip_masked=False,
skip_nomask=False,
num_classes=num_classes,
final_dim=256,
)
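# Usage sketch (illustrative only): the base model is pre-trained in two iterations;
# the first uses 100 k-means clusters and the second re-clusters intermediate
# features into 500 clusters, hence the ``num_classes`` argument.
def _example_base_iterations() -> Tuple[HuBERTPretrainModel, HuBERTPretrainModel]:
    first_iteration = hubert_pretrain_base(num_classes=100)
    second_iteration = hubert_pretrain_base(num_classes=500)
    return first_iteration, second_iteration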
def hubert_pretrain_large(
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model for pre-training with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_attention_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_ff_interm_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_layer_drop (float):
See :py:func:`hubert_pretrain_model`.
Returns:
HuBERTPretrainModel:
The resulting model.
""" # noqa: E501
return hubert_pretrain_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
mask_prob=0.80,
mask_selection="static",
mask_other=0.0,
mask_length=10,
no_mask_overlap=False,
mask_min_space=1,
mask_channel_prob=0.0,
mask_channel_selection="static",
mask_channel_other=0.0,
mask_channel_length=10,
no_mask_channel_overlap=False,
mask_channel_min_space=1,
skip_masked=False,
skip_nomask=False,
num_classes=500,
final_dim=768,
)
def hubert_pretrain_xlarge(
encoder_projection_dropout: float = 0.0,
encoder_attention_dropout: float = 0.0,
encoder_ff_interm_dropout: float = 0.0,
encoder_dropout: float = 0.0,
encoder_layer_drop: float = 0.0,
) -> HuBERTPretrainModel:
# Overriding the signature so that the return type is correct on Sphinx
"""hubert_pretrain_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0) -> torchaudio.models.HuBERTPretrainModel
Build HuBERTPretrainModel model for pre-training with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`]
Args:
encoder_projection_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_attention_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_ff_interm_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_dropout (float):
See :py:func:`hubert_pretrain_model`.
encoder_layer_drop (float):
See :py:func:`hubert_pretrain_model`.
Returns:
HuBERTPretrainModel:
The resulting model.
""" # noqa: E501
return hubert_pretrain_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1280,
encoder_projection_dropout=encoder_projection_dropout,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=48,
encoder_num_heads=16,
encoder_attention_dropout=encoder_attention_dropout,
encoder_ff_interm_features=5120,
encoder_ff_interm_dropout=encoder_ff_interm_dropout,
encoder_dropout=encoder_dropout,
encoder_layer_norm_first=True,
encoder_layer_drop=encoder_layer_drop,
mask_prob=0.80,
mask_selection="static",
mask_other=0.0,
mask_length=10,
no_mask_overlap=False,
mask_min_space=1,
mask_channel_prob=0.0,
mask_channel_selection="static",
mask_channel_other=0.0,
mask_channel_length=10,
no_mask_channel_overlap=False,
mask_channel_min_space=1,
skip_masked=False,
skip_nomask=False,
num_classes=500,
final_dim=1024,
)
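# Usage sketch (illustrative only): the three factory functions share the masking
# configuration and differ in extractor normalization, encoder size, ``final_dim``,
# and default dropout values. Their sizes can be compared directly (building all
# three at once is memory-hungry).
def _example_compare_sizes() -> dict:  # hypothetical helper, for illustration
    models = {
        "base": hubert_pretrain_base(),
        "large": hubert_pretrain_large(),
        "xlarge": hubert_pretrain_xlarge(),
    }
    return {name: sum(p.numel() for p in m.parameters()) for name, m in models.items()}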