Commit 09aabcc1 authored by moto, committed by Facebook GitHub Bot

Refactor wav2vec2 pipeline misc helper functions (#3527)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3527

Reviewed By: huangruizhe

Differential Revision: D48008822

Pulled By: mthrok

fbshipit-source-id: 4beae2956dfd1f00534832b70a1bf0897cba7812
parent 72b0917d
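
Since this change is a pure refactor, the public bundle API should behave exactly as before. As a quick sanity sketch (not part of the commit), the usual bundle workflow with an existing pretrained bundle such as `WAV2VEC2_ASR_BASE_960H` looks like this:

```python
import torch
import torchaudio

# Sanity sketch (not part of this commit): the refactor should leave the
# public bundle workflow untouched.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()  # downloads the pretrained weights on first use

waveform = torch.randn(1, int(bundle.sample_rate))  # 1 second of dummy audio
with torch.inference_mode():
    emissions, _ = model(waveform)

# One logit per label; get_labels() prepends the customizable blank token.
print(emissions.shape, len(bundle.get_labels()))
```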
@@ -27,7 +27,7 @@ RNN-T Streaming/Non-Streaming ASR
---------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization.
@@ -47,7 +47,7 @@ Interface
.. minigallery:: torchaudio.pipelines.RNNTBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
   :toctree: generated
@@ -61,7 +61,7 @@ wav2vec 2.0 / HuBERT / WavLM - SSL
----------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning.
@@ -75,7 +75,7 @@ Interface
   Wav2Vec2Bundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
   :toctree: generated
@@ -100,7 +100,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``Wav2Vec2ASRBundle`` instantiates models that generate a probability distribution over pre-defined labels, which can be used for ASR.
@@ -118,7 +118,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
   :toctree: generated
@@ -157,7 +157,7 @@ Tacotron2 Text-To-Speech
Similarly, ``Vocoder`` can be an algorithm without learned parameters, like `Griffin-Lim`, or a neural-network-based model like `WaveGlow`.

Interface
-^^^^^^^^^
+~~~~~~~~~

.. autosummary::
   :toctree: generated
@@ -173,7 +173,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
   :toctree: generated
@@ -189,7 +189,7 @@ Source Separation
-----------------

Interface
-^^^^^^^^^
+~~~~~~~~~

``SourceSeparationBundle`` instantiates source separation models which take single-channel audio and generate multi-channel audio.
@@ -207,7 +207,7 @@ Interface
.. minigallery:: torchaudio.pipelines.SourceSeparationBundle

Pretrained Models
-^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~

.. autosummary::
   :toctree: generated
......
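
The ``RNNTBundle`` interface documented above composes three callables. As a minimal sketch based on the existing torchaudio tutorial API (the input file name is a placeholder, not from this commit):

```python
import torchaudio

# Minimal sketch of the three RNNTBundle steps named above: feature
# extraction, inference, and de-tokenization. "speech.wav" is a placeholder.
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
feature_extractor = bundle.get_feature_extractor()
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()

waveform, sample_rate = torchaudio.load("speech.wav")
assert sample_rate == bundle.sample_rate  # the bundle expects 16 kHz input
features, length = feature_extractor(waveform.squeeze())
hypotheses = decoder(features, length, 10)  # beam width of 10
print(token_processor(hypotheses[0][0]))   # best hypothesis, de-tokenized
```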
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Tuple

-import torch
-from torch import Tensor
-from torch.nn import functional as F, Module
-from torchaudio._internal import load_state_dict_from_url
-from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+from torch.nn import Module
+
+from . import utils

-__all__ = []
-
-
-class _Wav2Vec2Model(Module):
-    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
-
-    This is used for layer normalization at the input
-    """
-
-    def __init__(self, model: Wav2Vec2Model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
-
-    @torch.jit.export
-    def extract_features(
-        self,
-        waveforms: Tensor,
-        lengths: Optional[Tensor] = None,
-        num_layers: Optional[int] = None,
-    ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model.extract_features(waveforms, lengths, num_layers)
+__all__ = []  # type: ignore
@dataclass
@@ -84,10 +55,8 @@ class Wav2Vec2Bundle:
        return self._sample_rate

    def _get_state_dict(self, dl_kwargs):
-        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(url, **dl_kwargs)
-        return state_dict
+        # Note: This method is overridden in ASR bundle
+        return utils._get_state_dict(self._path, dl_kwargs)

    def get_model(self, *, dl_kwargs=None) -> Module:
        """Construct the model and load the pretrained weight.
@@ -119,13 +88,11 @@ class Wav2Vec2Bundle:
            - HUBERT_ASR_XLARGE
            - WAVLM_LARGE
        """
-        if self._model_type == "WavLM":
-            model = wavlm_model(**self._params)
-        else:
-            model = wav2vec2_model(**self._params)
-        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = self._get_state_dict(dl_kwargs)
+        model.load_state_dict(state_dict)
        if self._normalize_waveform:
-            model = _Wav2Vec2Model(model)
+            model = utils._apply_input_layer_norm(model)
        model.eval()
        return model
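
One subtle effect of delegating to ``utils._get_model`` (defined in the new file below): the old ``if``/``else`` silently built a wav2vec2 model for any non-``WavLM`` type string, while the new factory table fails fast on unknown types. A hypothetical probe, assuming the private module path ``torchaudio.pipelines._wav2vec2.utils``:

```python
from torchaudio.pipelines._wav2vec2 import utils  # assumed private module path

# The old if/else fell back to wav2vec2_model for any non-"WavLM" string;
# the new factory table rejects unknown model types instead.
try:
    utils._get_model("NotAModel", {})
except ValueError as err:
    print(err)  # Supported model types are ('Wav2Vec2', 'WavLM'). Found: NotAModel
```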
@@ -171,14 +138,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
        >>> transcripts = ctc_decode(emissions, labels)
    """  # noqa: E501

-    _labels: Tuple[str]
-    _remove_aux_axis: Tuple[int] = (1, 2, 3)
+    _labels: Tuple[str, ...]
+    _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)

    def get_labels(
        self,
        *,
        blank: str = "-",
-    ) -> Tuple[str]:
+    ) -> Tuple[str, ...]:
        """The output class labels (only applicable to fine-tuned bundles)

        The first is the blank token, and it is customizable.
@@ -187,7 +154,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
            blank (str, optional): Blank token. (default: ``'-'``)

        Returns:
-            Tuple[str]:
+            Tuple[str, ...]:
                For models fine-tuned on ASR, returns the tuple of strings representing
                the output class labels.
@@ -199,23 +166,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
        return (blank, *self._labels)

    def _get_state_dict(self, dl_kwargs):
-        state_dict = super()._get_state_dict(dl_kwargs)
-        if self._remove_aux_axis:
-            # Remove the seemingly unnecessary axes.
-            # For the ASR task, the pretrained weights that originate from fairseq have unrelated dimensions at
-            # indices 1, 2, 3. They come from the Dictionary implementation of fairseq, which was intended for
-            # NLP tasks but is not used during ASR training.
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
-            #
-            # Also, some pretrained weights that originate from voxpopuli have an extra dimension that is almost
-            # never used and that resembles a mistake.
-            # The label `1` shows up in the training dataset of German (1 out of 16M),
-            # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
-            for key in ["aux.weight", "aux.bias"]:
-                t = state_dict[key]
-                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
-        return state_dict
+        return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)


WAV2VEC2_BASE = Wav2Vec2Bundle(
......
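
Why the annotations in this hunk changed from ``Tuple[str]`` to ``Tuple[str, ...]``: in Python's typing, the former means a tuple of exactly one string, while the latter means a homogeneous tuple of arbitrary length, which is what the label tuples actually are. A quick illustration:

```python
from typing import Tuple

one: Tuple[str] = ("a",)                 # exactly one element
many: Tuple[str, ...] = ("a", "b", "c")  # any number of elements

# A label tuple like ("-", "|", "E", "T", ...) only satisfies Tuple[str, ...],
# so the old Tuple[str] annotations were flagged by type checkers.
```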
from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model


def _get_model(type_, params):
    factories = {
        "Wav2Vec2": wav2vec2_model,
        "WavLM": wavlm_model,
    }
    if type_ not in factories:
        raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
    factory = factories[type_]
    return factory(**params)


class _Wav2Vec2Model(nn.Module):
    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.

    This is used for layer normalization at the input
    """

    def __init__(self, model: Wav2Vec2Model):
        super().__init__()
        self.model = model

    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        return self.model(waveforms, lengths)

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        return self.model.extract_features(waveforms, lengths, num_layers)


def _apply_input_layer_norm(module):
    """Add an extra layer_norm to the model"""
    return _Wav2Vec2Model(module)


def _remove_aux_axes(state_dict, axes):
    # Remove the seemingly unnecessary axes.
    # For the ASR task, the pretrained weights that originate from fairseq have unrelated dimensions at
    # indices 1, 2, 3. They come from the Dictionary implementation of fairseq, which was intended for
    # NLP tasks but is not used during ASR training.
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
    #
    # Also, some pretrained weights that originate from voxpopuli have an extra dimension that is almost
    # never used and that resembles a mistake.
    # The label `1` shows up in the training dataset of German (1 out of 16M),
    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
    for key in ["aux.weight", "aux.bias"]:
        mat = state_dict[key]
        state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])


def _get_state_dict(url, dl_kwargs, remove_axes=None):
    if not url.startswith("https"):
        url = f"https://download.pytorch.org/torchaudio/models/{url}"
    dl_kwargs = {} if dl_kwargs is None else dl_kwargs
    state_dict = load_state_dict_from_url(url, **dl_kwargs)
    if remove_axes:
        _remove_aux_axes(state_dict, remove_axes)
    return state_dict


def _get_en_labels():
    return (
        "|",
......
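
A toy check of ``_remove_aux_axes`` semantics (not part of the commit; assumes the helper above is importable): rows 1, 2, and 3 of the aux projection are dropped, so a ``(5, 4)`` weight shrinks to ``(2, 4)`` and the matching bias from 5 entries to 2.

```python
import torch

state_dict = {
    "aux.weight": torch.arange(20.0).reshape(5, 4),  # made-up values
    "aux.bias": torch.zeros(5),
}
_remove_aux_axes(state_dict, (1, 2, 3))
assert state_dict["aux.weight"].shape == (2, 4)  # rows 0 and 4 remain
assert state_dict["aux.bias"].shape == (2,)
```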