Unverified Commit e6886a4d authored by moto, committed by GitHub

Add wav2vec2.0 model (#1529)

- TorchScript-able `Wav2Vec2Model` class
- Factory functions for three configurations presented in the paper 
  - `wav2vec2_base`
  - `wav2vec2_large`
  - `wav2vec2_large_lv60k`
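
A minimal usage sketch, not part of the diff (the 32-label output size is an arbitrary illustrative value):

```python
import torch
from torchaudio.models import wav2vec2_base

# Build the Base configuration; `num_out` is the size of the output label set.
model = wav2vec2_base(num_out=32).eval()

waveforms = torch.randn(2, 16000)        # (batch, time) raw audio
lengths = torch.tensor([16000, 8000])    # valid samples per batch element

with torch.no_grad():
    logits, out_lengths = model(waveforms, lengths)

print(logits.shape)     # (batch, frames, num_out) logits over labels
print(out_lengths)      # valid frames per batch element
```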
parent 838e1e0a
@@ -33,6 +33,29 @@ The models subpackage contains definitions of models for addressing common audio
.. automethod:: forward

:hidden:`Wav2Vec2.0`
~~~~~~~~~~~~~~~~~~~~

Model
-----

.. autoclass:: Wav2Vec2Model

  .. automethod:: extract_features

  .. automethod:: forward

Factory Functions
-----------------

.. autofunction:: wav2vec2_base

.. autofunction:: wav2vec2_large

.. autofunction:: wav2vec2_large_lv60k

:hidden:`WaveRNN`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -19,6 +19,7 @@ from .case_utils import (
skipIfNoKaldi,
skipIfNoSox,
skipIfRocm,
skipIfNoQengine,
)
from .wav_utils import (
get_wav_data,
@@ -49,6 +50,7 @@ __all__ = [
'skipIfNoSox',
'skipIfNoSoxBackend',
'skipIfRocm',
'skipIfNoQengine',
'get_wav_data',
'normalize_wav',
'load_wav',
@@ -109,3 +109,7 @@ skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available'
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1',
reason="test doesn't currently work on the ROCm stack")
skipIfNoQengine = unittest.skipIf(
'fbgemm' not in torch.backends.quantized.supported_engines,
reason="`fbgemm` is not available."
)
import io
import torch
import torch.nn.functional as F
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoQengine,
skipIfNoCuda,
)
from parameterized import parameterized
factory_funcs = parameterized.expand([
(wav2vec2_base, ),
(wav2vec2_large, ),
(wav2vec2_large_lv60k, ),
])
class TestWav2Vec2Model(TorchaudioTestCase):
def _smoke_test(self, device, dtype):
model = wav2vec2_base(num_out=32)
model = model.to(device=device, dtype=dtype)
model = model.eval()
torch.manual_seed(0)
batch_size, num_frames = 3, 1024
waveforms = torch.randn(
batch_size, num_frames, device=device, dtype=dtype)
lengths = torch.randint(
low=0, high=num_frames, size=[batch_size, ], device=device)
model(waveforms, lengths)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
def test_cpu_smoke_test(self, dtype):
self._smoke_test(torch.device('cpu'), dtype)
@parameterized.expand([(torch.float32, ), (torch.float64, )])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
self._smoke_test(torch.device('cuda'), dtype)
@factory_funcs
def test_feature_extractor_smoke_test(self, factory_func):
"""`extract_features` method does not fail"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
features, lengths = model.extract_features(waveforms, lengths)
assert features.ndim == 3
assert features.shape[0] == batch_size
assert lengths.shape == torch.Size([batch_size])
@factory_funcs
def test_batch_consistency(self, factory_func):
"""Results from sigle process and batched process should be reasonably close
"""
batch_size, max_frames = 5, 5 * 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, max_frames)
input_lengths = torch.tensor([i * 3200 for i in range(1, 6)])
# Batch process with lengths
batch_logits, output_lengths = model(waveforms, input_lengths)
for i in range(batch_size):
            # Per-sample processing without feeding length
single_logit, _ = model(waveforms[i:i + 1, :input_lengths[i]], None)
batch_logit = batch_logits[i:i + 1, :output_lengths[i]]
            # Convert to probabilities so that the difference is easier to interpret
single_prob = F.softmax(single_logit, dim=2)
batch_prob = F.softmax(batch_logit, dim=2)
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs
def test_zero_length(self, factory_func):
"""Passing zero length should not fail"""
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
batch_size = 3
waveforms = torch.randn(batch_size, 1024)
input_lengths = torch.zeros(batch_size)
_, output_lengths = model(waveforms, input_lengths)
self.assertEqual(torch.zeros_like(output_lengths), output_lengths)
@factory_funcs
def test_torchscript(self, factory_func):
"""Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
ref_out, ref_len = model(waveforms, lengths)
# TODO: put this in a common method of Mixin class.
# Script
scripted = torch.jit.script(model)
buffer_ = io.BytesIO()
torch.jit.save(scripted, buffer_)
buffer_.seek(0)
scripted = torch.jit.load(buffer_)
hyp_out, hyp_len = scripted(waveforms, lengths)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs
@skipIfNoQengine
def test_quantize(self, factory_func):
"""Wav2Vec2Model should support basic quantization"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
_, _ = quantized(waveforms, lengths)
@factory_funcs
@skipIfNoQengine
def test_quantize_torchscript(self, factory_func):
"""Quantized Wav2Vec2Model should be scriptable"""
batch_size, num_frames = 3, 1024
model = factory_func(num_out=32).eval()
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = torch.quantization.quantize_dynamic(
model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
torch.manual_seed(0)
waveforms = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
ref_out, ref_len = quantized(waveforms, lengths)
# Script
scripted = torch.jit.script(quantized)
buffer_ = io.BytesIO()
torch.jit.save(scripted, buffer_)
buffer_.seek(0)
scripted = torch.jit.load(buffer_)
hyp_out, hyp_len = scripted(waveforms, lengths)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
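
The quantization tests above reduce to the following workflow; here is a minimal sketch outside the test harness, assuming the `fbgemm` quantized engine is available (which is what `skipIfNoQengine` guards):

```python
import io
import torch
from torchaudio.models import wav2vec2_base

model = wav2vec2_base(num_out=32).eval()
# Remove the weight-normalization forward hook so the module can be scripted.
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()

# Dynamically quantize the Linear layers to int8.
quantized = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

# Script and round-trip through an in-memory buffer, as the tests do.
scripted = torch.jit.script(quantized)
buffer_ = io.BytesIO()
torch.jit.save(scripted, buffer_)
buffer_.seek(0)
restored = torch.jit.load(buffer_)

waveforms = torch.randn(1, 16000)
logits, _ = restored(waveforms, None)   # lengths may be omitted (None)
```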
@@ -2,10 +2,21 @@ from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
__all__ = [
'Wav2Letter',
'WaveRNN',
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
]
from .model import (
Wav2Vec2Model,
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
__all__ = [
'Wav2Vec2Model',
'wav2vec2_base',
'wav2vec2_large',
'wav2vec2_large_lv60k',
]
from typing import Optional, Tuple, List
from torch import Tensor
from torch.nn import Module
from . import components
class Wav2Vec2Model(Module):
"""Model used in wav2vec2.0 paper. [1]
Note:
To build the model, please use one of the factory functions.
Args:
feature_extractor (torch.nn.Module):
Feature extractor that extracts feature vectors from raw audio Tensor.
encoder (torch.nn.Module):
            Encoder that converts the audio features into a sequence of probability
            distributions (in negative log-likelihood) over labels.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
def __init__(
self,
feature_extractor: Module,
encoder: Module,
):
super().__init__()
self.feature_extractor = feature_extractor
self.encoder = encoder
def extract_features(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Extract feature vectors from raw waveforms
Args:
waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
lengths (Tensor, optional):
Indicates the valid length of each audio sample in the batch.
Shape: ``(batch, )``.
Returns:
Tensor:
Feature vectors.
                Shape: ``(batch, frames, feature dimension)``
Tensor, optional:
Indicates the valid length of each feature in the batch, computed
based on the given ``lengths`` argument.
Shape: ``(batch, )``.
"""
return self.feature_extractor(waveforms, lengths)
def forward(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Compute the sequence of probability distribution over labels.
Args:
waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
lengths (Tensor, optional):
Indicates the valid length of each audio sample in the batch.
Shape: ``(batch, )``.
Returns:
Tensor:
                The sequence of probability distributions (in logits) over labels.
Shape: ``(batch, frames, num labels)``.
Tensor, optional:
Indicates the valid length of each feature in the batch, computed
based on the given ``lengths`` argument.
Shape: ``(batch, )``.
"""
x, lengths = self.feature_extractor(waveforms, lengths)
return self.encoder(x, lengths), lengths
def _get_model(
extractor_mode: str,
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]],
extractor_conv_bias: bool,
encoder_embed_dim: int,
encoder_projection_dropout: float,
encoder_pos_conv_kernel: int,
encoder_pos_conv_groups: int,
encoder_num_layers: int,
encoder_num_heads: int,
encoder_attention_dropout: float,
encoder_ff_interm_features: int,
encoder_ff_interm_dropout: float,
encoder_dropout: float,
encoder_layer_norm_first: bool,
encoder_layer_drop: float,
encoder_num_out: int,
) -> Wav2Vec2Model:
if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
feature_extractor = components._get_feature_extractor(
extractor_mode, extractor_conv_layer_config, extractor_conv_bias)
encoder = components._get_encoder(
in_features=extractor_conv_layer_config[-1][0],
embed_dim=encoder_embed_dim,
dropout_input=encoder_projection_dropout,
pos_conv_kernel=encoder_pos_conv_kernel,
pos_conv_groups=encoder_pos_conv_groups,
num_layers=encoder_num_layers,
num_heads=encoder_num_heads,
attention_dropout=encoder_attention_dropout,
ff_interm_features=encoder_ff_interm_features,
ff_interm_dropout=encoder_ff_interm_dropout,
dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop,
num_out=encoder_num_out,
)
return Wav2Vec2Model(feature_extractor, encoder)
def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Base** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=768,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=12,
encoder_num_heads=12,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=3072,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Large** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="group_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=False,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.1,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.1,
encoder_layer_norm_first=False,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
"""Build wav2vec2.0 model with **Large LV-60k** configuration. [1]
Args:
num_out: int
The number of output labels.
Returns:
Wav2Vec2Model: The resulting model.
Reference:
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
https://arxiv.org/abs/2006.11477
"""
return _get_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=None,
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.1,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
encoder_num_out=num_out,
)
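
Based on the docstrings above, a sketch of how `extract_features` relates to `forward` (the 29-label output size is an arbitrary example, not something fixed by this commit):

```python
import torch
from torchaudio.models import wav2vec2_large_lv60k

model = wav2vec2_large_lv60k(num_out=29).eval()

waveforms = torch.randn(2, 32000)          # (batch, time) raw audio
lengths = torch.tensor([32000, 24000])     # valid samples per batch element

with torch.no_grad():
    # Convolutional feature extractor output: (batch, frames, feature dimension)
    features, feat_lengths = model.extract_features(waveforms, lengths)
    # Full model: logits over labels, shape (batch, frames, num_out)
    logits, out_lengths = model(waveforms, lengths)
```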