Commit 8bde6a54 authored by Caroline Chen, committed by Facebook GitHub Bot

Add conformer wav2vec2 pretrain model (#2827)

Summary:
Modeled after the [paper](https://arxiv.org/pdf/2110.07313.pdf) and internal flow f288347302.

Internal comparison tests: D40080919

Pull Request resolved: https://github.com/pytorch/audio/pull/2827

Reviewed By: nateanl

Differential Revision: D41569046

Pulled By: carolineechen

fbshipit-source-id: 43c5313074af05972d93da55b2029c746b75c380
parent b0795ebe
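
For context, a minimal usage sketch of the new factory functions (illustrative only; the shapes mirror the smoke tests in this diff, and it assumes a torchaudio build that includes this change):

import torch
from torchaudio.prototype.models import conformer_wav2vec2_pretrain_base

# Build the base pretraining variant; inputs are (batch, num_frames, feature_dim=64),
# matching the shapes used by the smoke tests below.
model = conformer_wav2vec2_pretrain_base().eval()

features = torch.randn(3, 1024, 64)
lengths = torch.randint(low=0, high=1024, size=(3,))

# Forward pass of the pretraining model (the documented entry point is
# ConformerWav2Vec2PretrainModel.forward, added to the docs below).
outputs = model(features, lengths)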
@@ -33,6 +33,13 @@ ConvEmformer
 
    .. automethod:: infer
 
+ConformerWav2Vec2PretrainModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ConformerWav2Vec2PretrainModel
+
+   .. automethod:: forward
+
 conformer_wav2vec2_model
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -42,3 +49,18 @@ conformer_wav2vec2_base
 ~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: conformer_wav2vec2_base
+
+conformer_wav2vec2_pretrain_model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: conformer_wav2vec2_pretrain_model
+
+conformer_wav2vec2_pretrain_base
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: conformer_wav2vec2_pretrain_base
+
+conformer_wav2vec2_pretrain_large
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: conformer_wav2vec2_pretrain_large
import torch
from parameterized import parameterized
from torchaudio.prototype.models import (
conformer_wav2vec2_base,
conformer_wav2vec2_pretrain_base,
conformer_wav2vec2_pretrain_large,
)
from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase


class TestConformerWav2Vec2(TorchaudioTestCase):
    def _smoke_test(self, model, device, dtype):
        model = model.to(device=device, dtype=dtype)
        model = model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features, device=device, dtype=dtype)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
            device=device,
        )

        model(features, lengths)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        model = conformer_wav2vec2_base()
        self._smoke_test(model, torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        model = conformer_wav2vec2_base()
        self._smoke_test(model, torch.device("cuda"), dtype)

    @nested_params(
        [conformer_wav2vec2_pretrain_base, conformer_wav2vec2_pretrain_large],
        [torch.float32, torch.float64],
    )
    def test_pretrain_cpu_smoke_test(self, model, dtype):
        model = model()
        self._smoke_test(model, torch.device("cpu"), dtype)

    @nested_params(
        [conformer_wav2vec2_pretrain_base, conformer_wav2vec2_pretrain_large],
        [torch.float32, torch.float64],
    )
    @skipIfNoCuda
    def test_pretrain_cuda_smoke_test(self, model, dtype):
        model = model()
        self._smoke_test(model, torch.device("cuda"), dtype)

    def test_extract_feature(self):
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        num_layers = len(model.encoder.conformer)

        features = torch.randn(batch_size, num_frames, in_features)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
        )

        all_features, lengths_ = model.extract_features(features, lengths, num_layers=None)
        assert len(all_features) == num_layers
        for feats in all_features:
            assert feats.ndim == 3
            assert feats.shape[0] == batch_size
        assert lengths_.shape == torch.Size([batch_size])

        for l in range(1, num_layers + 1):
            feats, lengths_ = model.extract_features(features, lengths, num_layers=l)
            assert len(feats) == l
            for i in range(l):
                self.assertEqual(all_features[i], feats[i])
            assert lengths_.shape == torch.Size([batch_size])

    def test_zero_length(self):
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features)
        input_lengths = torch.zeros(batch_size)
        _, output_lengths = model(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

        _, output_lengths = model.extract_features(features, input_lengths)
        self.assertEqual(torch.zeros_like(output_lengths), output_lengths)

    def test_torchscript_consistency(self):
        model = conformer_wav2vec2_base()
        model.eval()

        batch_size, num_frames, in_features = 3, 1024, 64
        features = torch.randn(batch_size, num_frames, in_features)
        lengths = torch.randint(
            low=0,
            high=num_frames,
            size=[
                batch_size,
            ],
        )

        ref_out, ref_len = model(features, lengths)
        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(features, lengths)
        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)
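
As a side note on the extract_features API exercised above, a short illustrative sketch (not part of this diff):

import torch
from torchaudio.prototype.models import conformer_wav2vec2_base

model = conformer_wav2vec2_base().eval()
features = torch.randn(3, 1024, 64)
lengths = torch.full((3,), 1024)

# num_layers=2 returns the outputs of the first two conformer layers;
# num_layers=None (as in test_extract_feature above) returns all of them.
layer_outputs, out_lengths = model.extract_features(features, lengths, num_layers=2)
assert len(layer_outputs) == 2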
@@ -1,4 +1,4 @@
 import torch
 from parameterized import parameterized
-from torchaudio.prototype.models import conformer_wav2vec2_base, emformer_hubert_base
+from torchaudio.prototype.models import conformer_wav2vec2_base, conformer_wav2vec2_pretrain_base, emformer_hubert_base
 from torchaudio_unittest.common_utils import nested_params, skipIfNoCuda, torch_script, TorchaudioTestCase
@@ -23,7 +23,7 @@ class TestSSLModel(TorchaudioTestCase):
         model(features, lengths)
 
     @nested_params(
-        [(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)],
+        [(conformer_wav2vec2_base, 64), (conformer_wav2vec2_pretrain_base, 64), (emformer_hubert_base, 80)],
         [torch.float32, torch.float64],
     )
     def test_cpu_smoke_test(self, model_feature_dim, dtype):
@@ -32,7 +32,7 @@ class TestSSLModel(TorchaudioTestCase):
         self._smoke_test(model, feature_dim, torch.device("cpu"), dtype)
 
     @nested_params(
-        [(conformer_wav2vec2_base, 64), (emformer_hubert_base, 80)],
+        [(conformer_wav2vec2_base, 64), (conformer_wav2vec2_pretrain_base, 64), (emformer_hubert_base, 80)],
         [torch.float32, torch.float64],
     )
     @skipIfNoCuda
...
@@ -1,4 +1,11 @@
-from ._conformer_wav2vec2 import conformer_wav2vec2_base, conformer_wav2vec2_model
+from ._conformer_wav2vec2 import (
+    conformer_wav2vec2_base,
+    conformer_wav2vec2_model,
+    conformer_wav2vec2_pretrain_base,
+    conformer_wav2vec2_pretrain_large,
+    conformer_wav2vec2_pretrain_model,
+    ConformerWav2Vec2PretrainModel,
+)
 from ._emformer_hubert import emformer_hubert_base, emformer_hubert_model
 from .conv_emformer import ConvEmformer
 from .rnnt import conformer_rnnt_base, conformer_rnnt_model
@@ -9,6 +16,10 @@ __all__ = [
     "ConvEmformer",
     "conformer_wav2vec2_model",
     "conformer_wav2vec2_base",
+    "conformer_wav2vec2_pretrain_model",
+    "conformer_wav2vec2_pretrain_base",
+    "conformer_wav2vec2_pretrain_large",
+    "ConformerWav2Vec2PretrainModel",
     "emformer_hubert_base",
     "emformer_hubert_model",
 ]