Commit aca61bc0 authored by Andreas Floros, committed by Facebook GitHub Bot

Add layer normalization to wav2vec2 large+ pretrained models (#2873)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2873

The original fairseq implementation applies an extra layer normalization
preprocessing step for large/xlarge models.

https://github.com/facebookresearch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/fairseq/data/audio/hubert_dataset.py#L355-L357

This commit modifies the pre-trained model bundles to include this
preprocessing for the impacted pre-trained models listed below,
as illustrated in the usage sketch after the commit metadata.
To keep the interface identical to that of the other models, and since
the additional preprocessing is rather simple, the returned pre-trained
model instance is modified to include the preprocessing, instead of
adding a separate preprocessing method.

- WAV2VEC2_LARGE_LV60K
- WAV2VEC2_ASR_LARGE_LV60K_10M
- WAV2VEC2_ASR_LARGE_LV60K_100H
- WAV2VEC2_ASR_LARGE_LV60K_960H
- WAV2VEC2_XLSR53
- HUBERT_LARGE
- HUBERT_XLARGE
- HUBERT_ASR_LARGE
- HUBERT_ASR_XLARGE
- WAVLM_LARGE

Reviewed By: nateanl

Differential Revision: D41520183

fbshipit-source-id: 83d72fe692e8b9fc25df144deb4ca946fcd09615
parent fc0720b4
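
For orientation, a minimal usage sketch of what this change means for downstream callers (an illustration, not part of the diff; assumes the checkpoint can be downloaded in the current environment):

# After this change, the bundles listed above apply the input layer
# normalization internally, so callers need no manual preprocessing.
import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_LARGE_LV60K
model = bundle.get_model()  # wrapped model that layer-normalizes its input

waveform = torch.randn(1, int(bundle.sample_rate))  # dummy 1-second mono clip
with torch.inference_mode():
    features, _ = model.extract_features(waveform)
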
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import functional as F, Module
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
@@ -11,6 +13,31 @@ from . import utils
__all__ = []
class _Wav2Vec2Model(Module):
    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.

    This is used for layer normalization at the input.
    """

    def __init__(self, model: Wav2Vec2Model):
        super().__init__()
        self.model = model

    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        waveforms = F.layer_norm(waveforms, waveforms.shape)
        return self.model(waveforms, lengths)

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        waveforms = F.layer_norm(waveforms, waveforms.shape)
        return self.model.extract_features(waveforms, lengths, num_layers)
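
An aside for clarity (not part of the diff): passing the tensor's full shape as ``normalized_shape`` makes ``F.layer_norm`` standardize all elements jointly, i.e. zero mean and unit variance over the whole waveform tensor, mirroring the fairseq preprocessing linked in the summary:

# Quick equivalence check: layer norm over the full shape is global
# standardization, (x - mean) / sqrt(var + eps) with the default eps = 1e-5.
import torch
from torch.nn import functional as F

x = torch.randn(1, 16000)
y = F.layer_norm(x, x.shape)
z = (x - x.mean()) / torch.sqrt(x.var(unbiased=False) + 1e-5)
assert torch.allclose(y, z, atol=1e-5)
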
@dataclass
class Wav2Vec2Bundle:
    """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model`.
@@ -45,6 +72,7 @@ class Wav2Vec2Bundle:
    _path: str
    _params: Dict[str, Any]
    _sample_rate: float
    _normalize_waveform: bool

    @property
    def sample_rate(self) -> float:
@@ -60,7 +88,7 @@ class Wav2Vec2Bundle:
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> Module:
        """Construct the model and load the pretrained weight.

        The weight file is downloaded from the internet and cached with
@@ -68,6 +96,24 @@ class Wav2Vec2Bundle:
        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.

            For the models listed below, an additional layer normalization is performed on the input.
            For all other models, a :py:class:`~torchaudio.models.Wav2Vec2Model` instance is returned.

            - WAV2VEC2_LARGE_LV60K
            - WAV2VEC2_ASR_LARGE_LV60K_10M
            - WAV2VEC2_ASR_LARGE_LV60K_100H
            - WAV2VEC2_ASR_LARGE_LV60K_960H
            - WAV2VEC2_XLSR53
            - HUBERT_LARGE
            - HUBERT_XLARGE
            - HUBERT_ASR_LARGE
            - HUBERT_ASR_XLARGE
            - WAVLM_LARGE
        """
        model_type = self._params.pop("model_type", None)
        if model_type == "WavLM":
@@ -75,6 +121,8 @@ class Wav2Vec2Bundle:
        else:
            model = wav2vec2_model(**self._params)
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        if self._normalize_waveform:
            model = _Wav2Vec2Model(model)
        model.eval()
        return model
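
Because the affected bundles now return the wrapper rather than the bare model, the return annotation is widened to ``Module``; the ``@torch.jit.export`` decorator keeps ``extract_features`` callable after scripting. A hedged sketch (assumes the checkpoint download succeeds):

# Sketch: the wrapped model remains TorchScript-compatible.
import torch
import torchaudio

model = torchaudio.pipelines.HUBERT_LARGE.get_model()
scripted = torch.jit.script(model)
feats, _ = scripted.extract_features(torch.randn(1, 16000), num_layers=2)
print(len(feats))  # one tensor per requested transformer layer
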
@@ -196,6 +244,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_BASE.__doc__ = """Wav2vec 2.0 model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -239,6 +288,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_BASE_10M.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -284,6 +334,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_BASE_100H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
@@ -329,6 +380,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_BASE_960H.__doc__ = """Wav2vec 2.0 model ("base" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -372,6 +424,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_LARGE.__doc__ = """Wav2vec 2.0 model ("large" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -415,6 +468,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_LARGE_10M.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -460,6 +514,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_LARGE_100H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -505,6 +560,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAV2VEC2_ASR_LARGE_960H.__doc__ = """Wav2vec 2.0 model ("large" architecture with an extra linear module),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -548,6 +604,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_LARGE_LV60K.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
@@ -591,6 +648,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
@@ -634,6 +692,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
@@ -678,6 +737,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Wav2vec 2.0 model ("large-lv60k" architecture with an extra linear module),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* :cite:`librilight` dataset, and
@@ -721,6 +781,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAV2VEC2_XLSR53.__doc__ = """Wav2vec 2.0 model ("base" architecture),
pre-trained on 56,000 hours of unlabeled audio from multiple datasets (
@@ -767,6 +828,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=False,
)
HUBERT_BASE.__doc__ = """HuBERT model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`
@@ -809,6 +871,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=True,
)
HUBERT_LARGE.__doc__ = """HuBERT model ("large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
@@ -851,6 +914,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
        "aux_num_out": None,
    },
    _sample_rate=16000,
    _normalize_waveform=True,
)
HUBERT_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`,
@@ -894,6 +958,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=True,
)
HUBERT_ASR_LARGE.__doc__ = """HuBERT model ("large" architecture),
pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* dataset :cite:`librilight`, and
@@ -938,6 +1003,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=True,
)
HUBERT_ASR_XLARGE.__doc__ = """HuBERT model ("extra large" architecture),
pre-trained on 60,000 hours of unlabeled audio from
@@ -985,6 +1051,7 @@ VOXPOPULI_ASR_BASE_10K_DE = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_de_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
    _remove_aux_axis=(1, 2, 3, 35),
)
VOXPOPULI_ASR_BASE_10K_DE.__doc__ = """wav2vec 2.0 model ("base" architecture),
@@ -1031,6 +1098,7 @@ VOXPOPULI_ASR_BASE_10K_EN = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_vp_en_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
    _remove_aux_axis=(1, 2, 3, 31),
)
VOXPOPULI_ASR_BASE_10K_EN.__doc__ = """wav2vec 2.0 model ("base" architecture),
@@ -1077,6 +1145,7 @@ VOXPOPULI_ASR_BASE_10K_ES = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_es_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
    _remove_aux_axis=(1, 2, 3, 35),
)
VOXPOPULI_ASR_BASE_10K_ES.__doc__ = """wav2vec 2.0 model ("base" architecture),
@@ -1122,6 +1191,7 @@ VOXPOPULI_ASR_BASE_10K_FR = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_fr_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
)
VOXPOPULI_ASR_BASE_10K_FR.__doc__ = """wav2vec 2.0 model ("base" architecture),
pre-trained on 10k hours of unlabeled audio from *VoxPopuli* dataset :cite:`voxpopuli`
@@ -1167,6 +1237,7 @@ VOXPOPULI_ASR_BASE_10K_IT = Wav2Vec2ASRBundle(
    },
    _labels=utils._get_it_labels(),
    _sample_rate=16000,
    _normalize_waveform=False,
    _remove_aux_axis=(1, 2, 3),
)
VOXPOPULI_ASR_BASE_10K_IT.__doc__ = """wav2vec 2.0 model ("base" architecture),
@@ -1215,6 +1286,7 @@ WAVLM_BASE = Wav2Vec2Bundle(
        "model_type": "WavLM",
    },
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAVLM_BASE.__doc__ = """WavLM Base model ("base" architecture),
pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset :cite:`7178964`, not fine-tuned.
@@ -1260,6 +1332,7 @@ WAVLM_BASE_PLUS = Wav2Vec2Bundle(
        "model_type": "WavLM",
    },
    _sample_rate=16000,
    _normalize_waveform=False,
)
WAVLM_BASE_PLUS.__doc__ = """WavLM Base+ model ("base" architecture),
pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,
@@ -1306,6 +1379,7 @@ WAVLM_LARGE = Wav2Vec2Bundle(
        "model_type": "WavLM",
    },
    _sample_rate=16000,
    _normalize_waveform=True,
)
WAVLM_LARGE.__doc__ = """WavLM Large model ("large" architecture),
pre-trained on 60,000 hours of Libri-Light dataset :cite:`librilight`, 10,000 hours of GigaSpeech :cite:`GigaSpeech2021`,