Unverified Commit e6886a4d authored by moto's avatar moto Committed by GitHub
Browse files

Add wav2vec2.0 model (#1529)

- TorchScript-able `Wav2Vec2Model` class
- Factory functions for three configurations presented in the paper 
  - `wav2vec2_base`
  - `wav2vec2_large`
  - `wav2vec2_large_lv60k`
parent 838e1e0a
...@@ -33,6 +33,29 @@ The models subpackage contains definitions of models for addressing common audio ...@@ -33,6 +33,29 @@ The models subpackage contains definitions of models for addressing common audio
.. automethod:: forward .. automethod:: forward
:hidden:`Wav2Vec2.0`
~~~~~~~~~~~~~~~~~~~~
Model
-----
.. autoclass:: Wav2Vec2Model
.. automethod:: extract_features
.. automethod:: forward
Factory Functions
-----------------
.. autofunction:: wav2vec2_base
.. autofunction:: wav2vec2_large
.. autofunction:: wav2vec2_large_lv60k
:hidden:`WaveRNN` :hidden:`WaveRNN`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -19,6 +19,7 @@ from .case_utils import ( ...@@ -19,6 +19,7 @@ from .case_utils import (
skipIfNoKaldi, skipIfNoKaldi,
skipIfNoSox, skipIfNoSox,
skipIfRocm, skipIfRocm,
skipIfNoQengine,
) )
from .wav_utils import ( from .wav_utils import (
get_wav_data, get_wav_data,
...@@ -49,6 +50,7 @@ __all__ = [ ...@@ -49,6 +50,7 @@ __all__ = [
'skipIfNoSox', 'skipIfNoSox',
'skipIfNoSoxBackend', 'skipIfNoSoxBackend',
'skipIfRocm', 'skipIfRocm',
'skipIfNoQengine',
'get_wav_data', 'get_wav_data',
'normalize_wav', 'normalize_wav',
'load_wav', 'load_wav',
......
...@@ -109,3 +109,7 @@ skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available' ...@@ -109,3 +109,7 @@ skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available'
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available') skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1', skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1',
reason="test doesn't currently work on the ROCm stack") reason="test doesn't currently work on the ROCm stack")
skipIfNoQengine = unittest.skipIf(
'fbgemm' not in torch.backends.quantized.supported_engines,
reason="`fbgemm` is not available."
)
import io
import torch
import torch.nn.functional as F
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoQengine,
skipIfNoCuda,
)
from parameterized import parameterized
factory_funcs = parameterized.expand([
(wav2vec2_base, ),
(wav2vec2_large, ),
(wav2vec2_large_lv60k, ),
])
class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, device, dtype):
model = wav2vec2_base(num_out=32)
model = model.to(device=device, dtype=dtype)
model = model.eval()
torch.manual_seed(0)
batch_size, num_frames = 3, 1024
waveforms = torch.randn(
batch_size, num_frames, device=device, dtype=dtype)
lengths = torch.randint(
low=0, high=num_frames, size=[batch_size, ], device=device)
model(waveforms, lengths)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
def test_cpu_smoke_test(self, dtype):
self._smoke_test(torch.device('cpu'), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
self._smoke_test(torch.device('cuda'), dtype)
@factory_funcs
def test_feature_extractor_smoke_test(self, factory_func):
"""`extract_features` method does not fail"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
features, lengths = model.extract_features(waveforms, lengths)
assert features.ndim == 3
assert features.shape[0] == batch_size
assert lengths.shape == torch.Size([batch_size])
@factory_funcs
def test_batch_consistency(self, factory_func):
"""Results from sigle process and batched process should be reasonably close
"""
batch_size, max_frames = 5, 5 * 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, max_frames)
input_lengths = torch.tensor([i * 3200 for i in range(1, 6)])
# Batch process with lengths
batch_logits, output_lengths = model(waveforms, input_lengths)
for i in range(batch_size):
# Par-sample process without feeding length
single_logit, _ = model(waveforms[i:i + 1, :input_lengths[i]], None)
batch_logit = batch_logits[i:i + 1, :output_lengths[i]]
# Convert to probability so that it's easier to interpretate the diff
single_prob = F.softmax(single_logit, dim=2)
batch_prob = F.softmax(batch_logit, dim=2)
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs
def test_zero_length(self, factory_func):
"""Passing zero length should not fail"""
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
batch_size = 3
waveforms = torch.randn(batch_size, 1024)
input_lengths = torch.zeros(batch_size)
_, output_lengths = model(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@factory_funcs
def test_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
ref_out, ref_len = model(waveforms, lengths)
# TODO: put this in a common method of Mixin class.
# Script
scripted = torch.jit.script(model)
buffer_ = io.BytesIO()
torch.jit.save(scripted, buffer_)
buffer_.seek(0)
scripted = torch.jit.load(buffer_)
hyp_out, hyp_len = scripted(waveforms, lengths)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs
@skipIfNoQengine
def test_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
_, _ = quantized(waveforms, lengths)
@factory_funcs
@skipIfNoQengine
def test_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
ref_out, ref_len = quantized(waveforms, lengths)
# Script
scripted = torch.jit.script(quantized)
buffer_ = io.BytesIO()
torch.jit.save(scripted, buffer_)
buffer_.seek(0)
scripted = torch.jit.load(buffer_)
hyp_out, hyp_len = scripted(waveforms, lengths)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
...@@ -2,10 +2,21 @@ from .wav2letter import Wav2Letter ...@@ -2,10 +2,21 @@ from .wav2letter import Wav2Letter
from .wavernn import WaveRNN from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech from .deepspeech import DeepSpeech
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
__all__ = [ __all__ = [
'Wav2Letter', 'Wav2Letter',
'WaveRNN', 'WaveRNN',
'ConvTasNet', 'ConvTasNet',
'DeepSpeech', 'DeepSpeech',
'Wav2Vec2Model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
] ]
from .model import (
Wav2Vec2Model,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
__all__ = [
'Wav2Vec2Model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
]
import logging
from typing import Optional, Tuple, List
import torch
from torch import Tensor, nn
from torch.nn import Module
_LG = logging.getLogger(__name__)
class LayerNorm(nn.LayerNorm):
"""Layer norm with transpose"""
def forward(self, input: Tensor) -> Tensor:
x = input.transpose(-2, -1)
x = nn.functional.layer_norm(
x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.transpose(-2, -1)
return x
class ConvLayerBlock(Module):
"""Convolution unit of FeatureExtractor"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
bias: bool,
layer_norm: Optional[Module],
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.layer_norm = layer_norm
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
bias=bias,
)
def forward(
self,
x: Tensor,
length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
length (Tensor, optional): Shape ``[batch, ]``.
Returns:
Tensor: Shape ``[batch, out_channels, out_frames]``.
Optional[Tensor]: Shape ``[batch, ]``.
"""
x = self.conv(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
x = nn.functional.gelu(x)
if length is not None:
length = torch.div(length - self.kernel_size, self.stride, rounding_mode='floor') + 1
# When input length is 0, the resulting length can be negative. So fix it here.
length = torch.max(torch.zeros_like(length), length)
return x, length
class FeatureExtractor(Module):
"""Extract features from audio
Args:
conv_layers (nn.ModuleList):
convolution layers
"""
def __init__(
self,
conv_layers: nn.ModuleList,
):
super().__init__()
self.conv_layers = conv_layers
def forward(
self,
x: Tensor,
length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x (Tensor):
Input Tensor representing a batch of audio,
shape: ``[batch, time]``.
length (Tensor, optional):
Valid length of each input sample. shape: ``[batch, ]``.
Returns:
Tensor:
The resulting feature, shape: ``[batch, frame, feature]``
Optional[Tensor]:
Valid length of each output sample. shape: ``[batch, ]``.
"""
if x.ndim != 2:
raise ValueError(
"Expected the input Tensor to be 2D (batch, time), "
"but received {list(x.shape)}")
x = x.unsqueeze(1) # (batch, channel==1, frame)
for layer in self.conv_layers:
x, length = layer(x, length) # (batch, feature, frame)
x = x.transpose(1, 2) # (batch, frame, feature)
return x, length
class FeatureProjection(Module):
"""Layer that connects FeatureExtractor and Encoder
Projects features to encoder dimension.
Args:
in_features (int): Input feature dim.
out_features (int): Output feature dim.
dropout (float): Dropout probability.
"""
def __init__(
self,
in_features: int,
out_features: int,
dropout: float,
):
super().__init__()
self.layer_norm = nn.LayerNorm(in_features)
self.projection = nn.Linear(in_features, out_features,)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x (Tensor):
Feature Tensor. shape: ``[batch, frame, in_feature]``
Returns:
Tensor: Projected features. ``[batch, frame, out_feature]``.
"""
x = self.layer_norm(x)
x = self.projection(x)
x = self.dropout(x)
return x
class ConvolutionalPositionalEmbedding(Module):
"""Positional embedding which is placed at the beginning of Transformer.
Args:
embed_dim (int): Feature dimension of the input Tensor.
kernel_size (int): The number of frames to be use.
groups (int): The number of groups in feature dimensions.
"""
def __init__(
self,
embed_dim: int,
kernel_size: int,
groups: int,
):
super().__init__()
self.embed_dim = embed_dim
self.conv = nn.Conv1d(
in_channels=embed_dim,
out_channels=embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
def __prepare_scriptable__(self):
for hook in self.conv._forward_pre_hooks.values():
# The hook we want to remove is an instance of WeightNorm class, so
# normally we would do `if isinstance(...)` but this class is not accessible
# because of shadowing, so we check the module name directly.
# https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
if (
hook.__module__ == 'torch.nn.utils.weight_norm' and
hook.__class__.__name__ == 'WeightNorm'
):
_LG.warning('Removing weight_norm from %s', self.__class__.__name__)
torch.nn.utils.remove_weight_norm(self.conv)
return self
def forward(self, x):
"""
Args:
x (Tensor): shape ``[batch, frame, feature]``.
Returns:
Tensor: The resulting feature. Shape ``[batch, frame, feature]``.
"""
x = x.transpose(-2, -1)
x = self.conv(x)
if self.num_remove > 0:
x = x[..., :-self.num_remove]
x = torch.nn.functional.gelu(x)
x = x.transpose(-2, -1)
return x
class SelfAttention(Module):
"""Multihead Self Attention module
Args:
embed_dim (int): Total dimension of the model.
num_heads (int): The number of heads.
dropout (float, optional):
Dropout probabiliry on attn_output_weights. Default: ``0.0``
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
):
super().__init__()
head_dim = embed_dim // num_heads
if head_dim * num_heads != embed_dim:
raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`")
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = head_dim
self.scaling = self.head_dim ** -0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
attention_mask (Tensor, optional):
shape: ``[batch_size, 1, sequence_length, sequence_length]``
Returns:
Tensor: The resulting tensor. shape: ``[batch, sequence_length, embed_dim]``
"""
if x.ndim != 3 or x.shape[2] != self.embed_dim:
raise ValueError(
f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). "
f"Found {x.shape}."
)
batch_size, length, embed_dim = x.size()
if attention_mask is not None:
shape_ = (batch_size, 1, length, length)
if attention_mask.size() != shape_:
raise ValueError(
f"The expected attention mask shape is {shape_}. "
f"Found {attention_mask.size()}."
)
shape = (batch_size, length, self.num_heads, self.head_dim)
q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L
v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd
weights = self.scaling * (q @ k) # B, nH, L, L
if attention_mask is not None:
weights += attention_mask
weights = torch.nn.functional.softmax(weights, dim=-1)
weights = torch.nn.functional.dropout(weights, p=self.dropout, training=self.training)
output = weights @ v # B, nH, L, Hd
output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
output = self.out_proj(output)
return output
class FeedForward(Module):
"""Layer that follows attention layer in encoder layer.
"""
def __init__(
self,
io_features: int,
intermediate_features: int,
intermediate_dropout: float,
output_dropout: float,
):
super().__init__()
self.intermediate_dense = nn.Linear(io_features, intermediate_features)
self.intermediate_dropout = nn.Dropout(intermediate_dropout)
self.output_dense = nn.Linear(intermediate_features, io_features)
self.output_dropout = nn.Dropout(output_dropout)
def forward(self, x):
"""
Args:
x (Tensor): shape: ``(batch, sequence_length, io_features)``
Returns:
x (Tensor): shape: ``(batch, sequence_length, io_features)``
"""
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
x = self.intermediate_dropout(x)
x = self.output_dense(x)
x = self.output_dropout(x)
return x
class EncoderLayer(Module):
"""A layer unit in encoder. Combines multihead self attention and feed forward.
"""
def __init__(
self,
attention: Module,
dropout: float,
layer_norm_first: bool,
feed_forward: Module,
):
super().__init__()
self.attention = attention
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(attention.embed_dim)
self.layer_norm_first = layer_norm_first
self.feed_forward = feed_forward
self.final_layer_norm = nn.LayerNorm(attention.embed_dim)
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
"""
Args:
x (Tensor): shape: ``(batch, sequence_length, embed_dim)``
attention_mask (Tensor, optional):
shape: ``(batch, 1, sequence_length, sequence_length)``
"""
residual = x
if self.layer_norm_first:
x = self.layer_norm(x)
x = self.attention(x, attention_mask)
x = self.dropout(x)
x = residual + x
if self.layer_norm_first:
x = x + self.feed_forward(self.final_layer_norm(x))
else:
x = self.layer_norm(x)
x = self.final_layer_norm(x + self.feed_forward(x))
return x
class Transformer(Module):
def __init__(
self,
pos_conv_embed: Module,
dropout: float,
layers: Module,
layer_norm_first: bool,
layer_drop: float,
):
super().__init__()
self.pos_conv_embed = pos_conv_embed
self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim)
self.layer_norm_first = layer_norm_first
self.layer_drop = layer_drop
self.dropout = nn.Dropout(dropout)
self.layers = layers
def forward(
self,
x: Tensor,
attention_mask: Optional[Tensor] = None,
):
x = x + self.pos_conv_embed(x)
if self.layer_norm_first:
x = self.layer_norm(x)
x = self.dropout(x)
for layer in self.layers:
if not (self.training and torch.rand(1).item() <= self.layer_drop):
x = layer(x, attention_mask)
if not self.layer_norm_first:
x = self.layer_norm(x)
return x
class Encoder(Module):
def __init__(
self,
feature_projection: Module,
transformer: Module,
readout: Module,
):
super().__init__()
self.feature_projection = feature_projection
self.transformer = transformer
self.readout = readout
def forward(
self,
features: Tensor,
lengths: Optional[Tensor] = None,
) -> Tensor:
x = self.feature_projection(features)
mask: Optional[Tensor] = None
if lengths is not None:
batch_size, max_len, _ = x.shape
# create mask for padded elements and zero-out them
mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
x[mask] = 0.0
# extend the mask to attention shape and set weight
mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
mask = mask.expand(batch_size, 1, max_len, max_len)
x = self.transformer(x, attention_mask=mask)
x = self.readout(x)
return x
################################################################################
def _get_feature_extractor(
norm_mode: str,
shapes: List[Tuple[int, int, int]],
bias: bool,
) -> FeatureExtractor:
"""
Args:
norm_mode (str):
Either "group_norm" or "layer_norm".
If "group_norm", then a single normalization is applied
in the first convolution block. Otherwise, all the convolution
blocks will have layer normalization.
This option corresponds to "extractor_mode" from fairseq.
Expected values are "group_norm" for Base arch, and
"layer_norm" for Large arch.
shapes (list of tuple of int):
Configuration of convolution layers. List of convolution configuration,
i.e. ``[(output_channel, kernel_size, stride), ...]``
This option corresponds to "conv_feature_layers" from fairseq.
Expected values are
``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2``
for all the architectures.
bias (bool):
Whether to include bias term to each convolution operation.
This option corresponds to "conv_bias" from fairseq.
Expected values are False for Base arch, and True for Large arch.
See Also:
* Original implementation
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733
* "extractor_mode"
- Def and base:
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45
- Large:
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52
* "conv_feature_layers"
- Def, base and large:
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100
* "conv_bias"
- Def and base:
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103
- Large:
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61
"""
assert norm_mode in ["group_norm", "layer_norm"]
blocks = []
in_channels = 1
for i, (out_channels, kernel_size, stride) in enumerate(shapes):
normalization = None
if norm_mode == "group_norm" and i == 0:
normalization = nn.GroupNorm(
num_groups=out_channels,
num_channels=out_channels,
affine=True,
)
elif norm_mode == "layer_norm":
normalization = LayerNorm(
normalized_shape=out_channels,
elementwise_affine=True,
)
blocks.append(
ConvLayerBlock(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
bias=bias,
layer_norm=normalization,
)
)
in_channels = out_channels
return FeatureExtractor(nn.ModuleList(blocks))
def _get_encoder(
in_features: int,
embed_dim: int,
dropout_input: float,
pos_conv_kernel: int,
pos_conv_groups: int,
num_layers: int,
num_heads: int,
attention_dropout: float,
ff_interm_features: int,
ff_interm_dropout: float,
dropout: float,
layer_norm_first: bool,
layer_drop: float,
num_out: int,
) -> Encoder:
"""
Args:
in_features (int): The number of input features.
embed_dim (int):
The dimension of embedding.
This option corresponds to "encoder_embed_dim" from fairseq.
Expected values are 768 for Base arch, and 1024 for Large arch.
dropout_input (float):
The dropout probability applied after the input feature is projected
to ``embed_dim``.
This option corresponds to "dropout_input" from fairseq.
Expected values are 0.1 for both Base and Large arch.
pos_conv_kernel (int):
The kernel size of convolutional positional embeddings.
This option corresponds to "conv_pos" from fairseq.
Expected values are 128 for both Base and Large arch.
pos_conv_groups (int):
The number of groups of convolutional positional embeddings.
This option corresponds to "conv_pos_groups" from fairseq.
Expected values are 16 for both Base and Large arch.
num_layers (int):
The number of self attention layers in transformer block.
This option corresponds to "encoder_layers" from fairseq.
Expected values are 12 for Base and 24 for Large arch.
num_heads (int):
The number of heads in self attention layers.
This option corresponds to "encoder_attention_heads" from fairseq.
Expected values are 12 for Base and 16 for Large arch.
attention_dropout (float):
The dropout probability applied after softmax in self-attention layer.
This option corresponds to "attention_dropout" from fairseq.
Expected values are 0.1 for Base and 0.0 for Large arch.
ff_interm_features (int):
The dimension of hidden features in feed forward layer.
This option corresponds to "encoder_ffn_embed_dim" from fairseq.
Expected values are 3072 for Base and 4096 for Large arch.
ff_interm_dropout (float):
The dropout probability applied in feedforward layer.
This option correspinds to "activation_dropout" from fairseq.
Expected values are 0.1 for both Base and Large arch.
dropout (float):
The dropout probability applied at the end of feed forward layer.
This option corresponds to "dropout" from fairseq.
Expected values are 0.1 for Base and 0.0 for Large arch.
layer_norm_first (bool):
Control the order of layer norm in transformer layer and each encoder layer.
If True, in transformer layer, layer norm is applied before features are fed
to encoder layers. In encoder layer, two layer norms are applied before and after
self attention.
If False, in transformer layer, layer norm is applied after features are fed
to encoder layers. In encoder layer, two layer norms are applied after self
attention, before and after feed forward.
This option corresponds to "layer_norm_first" from fairseq.
Expected values are False for Base and True for Large arch.
layer_drop (float):
Probability to drop each encoder layer during training.
This option corresponds to "layerdrop" from fairseq.
Expected values are 0.1 for both Base and Large arch.
num_out (int):
The dimension of the output. The number of labels.
See Also:
* "encoder_embed_dim"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
* "dropout_input"
- Def, base and large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
* "conv_pos"
- Def, base and large
NOTE: The description is wrong.
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207
- Usage
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756
* "conv_pos_groups"
- Def, base and large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211
* "encoder_layers"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63
* "encoder_attention_heads"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66
* "attention_dropout"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60
* "encoder_ffn_embed_dim"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65
* "activation_dropout"
- Def
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71
- Base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55
* "dropout"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59
* "layer_norm_first"
- Def and base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53
* "layerdrop"
- Def
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74
- Base
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54
- Large
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54
"""
feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)
# Original impl
# https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
encoder_layers = nn.ModuleList()
for _ in range(num_layers):
attention = SelfAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=attention_dropout,
)
feed_forward = FeedForward(
io_features=embed_dim,
intermediate_features=ff_interm_features,
intermediate_dropout=ff_interm_dropout,
output_dropout=dropout,
)
encoder_layers.append(
EncoderLayer(
attention=attention,
dropout=dropout,
layer_norm_first=layer_norm_first,
feed_forward=feed_forward,
)
)
transformer = Transformer(
pos_conv_embed=pos_conv,
dropout=dropout,
layers=encoder_layers,
layer_norm_first=not layer_norm_first,
layer_drop=layer_drop,
)
readout = nn.Linear(
in_features=embed_dim,
out_features=num_out,
)
return Encoder(feature_projection, transformer, readout)
from typing import Optional, Tuple, List
from torch import Tensor
from torch.nn import Module
from . import components
class Wav2Vec2Model(Module):
"""Model used in wav2vec2.0 paper. [1]
Note:
To build the model, please use one of the factory functions.
Args:
feature_extractor (torch.nn.Module):
Feature extractor that extracts feature vectors from raw audio Tensor.
encoder (torch.nn.Module):
Encoder that converts the audio features into the sequence of probability
distribution (in negative log-likelihood) over labels.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
def __init__(
self,
feature_extractor: Module,
encoder: Module,
):
super().__init__()
self.feature_extractor = feature_extractor
self.encoder = encoder
def extract_features(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Extract feature vectors from raw waveforms
Args:
waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
lengths (Tensor, optional):
Indicates the valid length of each audio sample in the batch.
Shape: ``(batch, )``.
Returns:
Tensor:
Feature vectors.
Shape: ``(batch, frames, feature dimention)``
Tensor, optional:
Indicates the valid length of each feature in the batch, computed
based on the given ``lengths`` argument.
Shape: ``(batch, )``.
"""
return self.feature_extractor(waveforms, lengths)
def forward(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Compute the sequence of probability distribution over labels.
Args:
waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
lengths (Tensor, optional):
Indicates the valid length of each audio sample in the batch.
Shape: ``(batch, )``.
Returns:
Tensor:
The sequences of probability distribution (in logit) over labels.
Shape: ``(batch, frames, num labels)``.
Tensor, optional:
Indicates the valid length of each feature in the batch, computed
based on the given ``lengths`` argument.
Shape: ``(batch, )``.
"""
x, lengths = self.feature_extractor(waveforms, lengths)
return self.encoder(x, lengths), lengths
def _get_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
encoder_num_out: int,
) -> Wav2Vec2Model:
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias)
encoder = components._get_encoder(
in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim,
dropout_input=encoder_projection_dropout,
pos_conv_kernel=encoder_pos_conv_kernel,
pos_conv_groups=encoder_pos_conv_groups,
num_layers=encoder_num_layers,
num_heads=encoder_num_heads,
attention_dropout=encoder_attention_dropout,
ff_interm_features=encoder_ff_interm_features,
ff_interm_dropout=encoder_ff_interm_dropout,
dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
num_out=encoder_num_out,
)
return Wav2Vec2Model(feature_extractor, encoder)
def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Base** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Large** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Large LV-60k** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment