Commit 09aabcc1 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor wav2vec2 pipeline misc helper functions (#3527)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3527

Reviewed By: huangruizhe

Differential Revision: D48008822

Pulled By: mthrok

fbshipit-source-id: 4beae2956dfd1f00534832b70a1bf0897cba7812
parent 72b0917d
...@@ -27,7 +27,7 @@ RNN-T Streaming/Non-Streaming ASR ...@@ -27,7 +27,7 @@ RNN-T Streaming/Non-Streaming ASR
--------------------------------- ---------------------------------
Interface Interface
^^^^^^^^^ ~~~~~~~~~
``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization. ``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization.
...@@ -47,7 +47,7 @@ Interface ...@@ -47,7 +47,7 @@ Interface
.. minigallery:: torchaudio.pipelines.RNNTBundle .. minigallery:: torchaudio.pipelines.RNNTBundle
Pretrained Models Pretrained Models
^^^^^^^^^^^^^^^^^ ~~~~~~~~~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
...@@ -61,7 +61,7 @@ wav2vec 2.0 / HuBERT / WavLM - SSL ...@@ -61,7 +61,7 @@ wav2vec 2.0 / HuBERT / WavLM - SSL
---------------------------------- ----------------------------------
Interface Interface
^^^^^^^^^ ~~~~~~~~~
``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning. ``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning.
...@@ -75,7 +75,7 @@ Interface ...@@ -75,7 +75,7 @@ Interface
Wav2Vec2Bundle Wav2Vec2Bundle
Pretrained Models Pretrained Models
^^^^^^^^^^^^^^^^^ ~~~~~~~~~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
...@@ -100,7 +100,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR ...@@ -100,7 +100,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
------------------------------------- -------------------------------------
Interface Interface
^^^^^^^^^ ~~~~~~~~~
``Wav2Vec2ASRBundle`` instantiates models that generate a probability distribution over pre-defined labels, which can be used for ASR. ``Wav2Vec2ASRBundle`` instantiates models that generate a probability distribution over pre-defined labels, which can be used for ASR.
...@@ -118,7 +118,7 @@ Interface ...@@ -118,7 +118,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle .. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle
Pretrained Models Pretrained Models
^^^^^^^^^^^^^^^^^ ~~~~~~~~~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
...@@ -157,7 +157,7 @@ Tacotron2 Text-To-Speech ...@@ -157,7 +157,7 @@ Tacotron2 Text-To-Speech
Similarly ``Vocoder`` can be an algorithm without learning parameters, like `Griffin-Lim`, or a neural-network-based model like `Waveglow`. Similarly ``Vocoder`` can be an algorithm without learning parameters, like `Griffin-Lim`, or a neural-network-based model like `Waveglow`.
Interface Interface
^^^^^^^^^ ~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
...@@ -173,7 +173,7 @@ Interface ...@@ -173,7 +173,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle .. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle
Pretrained Models Pretrained Models
^^^^^^^^^^^^^^^^^ ~~~~~~~~~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
...@@ -189,7 +189,7 @@ Source Separation ...@@ -189,7 +189,7 @@ Source Separation
----------------- -----------------
Interface Interface
^^^^^^^^^ ~~~~~~~~~
``SourceSeparationBundle`` instantiates source separation models which take single channel audio and generate multi-channel audio. ``SourceSeparationBundle`` instantiates source separation models which take single channel audio and generate multi-channel audio.
...@@ -207,7 +207,7 @@ Interface ...@@ -207,7 +207,7 @@ Interface
.. minigallery:: torchaudio.pipelines.SourceSeparationBundle .. minigallery:: torchaudio.pipelines.SourceSeparationBundle
Pretrained Models Pretrained Models
^^^^^^^^^^^^^^^^^ ~~~~~~~~~~~~~~~~~
.. autosummary:: .. autosummary::
:toctree: generated :toctree: generated
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, Tuple
import torch from torch.nn import Module
from torch import Tensor
from torch.nn import functional as F, Module
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
from . import utils from . import utils
__all__ = [] __all__ = [] # type: ignore
class _Wav2Vec2Model(Module):
    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.

    This is used for layer normalization at the input: every call first
    applies ``layer_norm`` over the whole input tensor, then delegates to
    the wrapped model. The wrapper is TorchScript-compatible
    (``extract_features`` is explicitly exported).
    """

    def __init__(self, model: Wav2Vec2Model):
        super().__init__()
        # The wrapped model; all inference is delegated to it after the
        # input waveforms have been layer-normalized.
        self.model = model

    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        # Normalizing with ``normalized_shape=waveforms.shape`` computes the
        # statistics over the entire tensor (all dimensions at once).
        waveforms = F.layer_norm(waveforms, waveforms.shape)
        return self.model(waveforms, lengths)

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        # Same input normalization as ``forward``; ``@torch.jit.export``
        # keeps this method callable after TorchScript compilation.
        waveforms = F.layer_norm(waveforms, waveforms.shape)
        return self.model.extract_features(waveforms, lengths, num_layers)
@dataclass @dataclass
...@@ -84,10 +55,8 @@ class Wav2Vec2Bundle: ...@@ -84,10 +55,8 @@ class Wav2Vec2Bundle:
return self._sample_rate return self._sample_rate
def _get_state_dict(self, dl_kwargs): def _get_state_dict(self, dl_kwargs):
url = f"https://download.pytorch.org/torchaudio/models/{self._path}" # Note: This method is overridden in ASR bundle
dl_kwargs = {} if dl_kwargs is None else dl_kwargs return utils._get_state_dict(self._path, dl_kwargs)
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
def get_model(self, *, dl_kwargs=None) -> Module: def get_model(self, *, dl_kwargs=None) -> Module:
"""Construct the model and load the pretrained weight. """Construct the model and load the pretrained weight.
...@@ -119,13 +88,11 @@ class Wav2Vec2Bundle: ...@@ -119,13 +88,11 @@ class Wav2Vec2Bundle:
- HUBERT_ASR_XLARGE - HUBERT_ASR_XLARGE
- WAVLM_LARGE - WAVLM_LARGE
""" """
if self._model_type == "WavLM": model = utils._get_model(self._model_type, self._params)
model = wavlm_model(**self._params) state_dict = self._get_state_dict(dl_kwargs)
else: model.load_state_dict(state_dict)
model = wav2vec2_model(**self._params)
model.load_state_dict(self._get_state_dict(dl_kwargs))
if self._normalize_waveform: if self._normalize_waveform:
model = _Wav2Vec2Model(model) model = utils._apply_input_layer_norm(model)
model.eval() model.eval()
return model return model
...@@ -171,14 +138,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -171,14 +138,14 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> transcripts = ctc_decode(emissions, labels) >>> transcripts = ctc_decode(emissions, labels)
""" # noqa: E501 """ # noqa: E501
_labels: Tuple[str] _labels: Tuple[str, ...]
_remove_aux_axis: Tuple[int] = (1, 2, 3) _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)
def get_labels( def get_labels(
self, self,
*, *,
blank: str = "-", blank: str = "-",
) -> Tuple[str]: ) -> Tuple[str, ...]:
"""The output class labels (only applicable to fine-tuned bundles) """The output class labels (only applicable to fine-tuned bundles)
The first is blank token, and it is customizable. The first is blank token, and it is customizable.
...@@ -187,7 +154,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -187,7 +154,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
blank (str, optional): Blank token. (default: ``'-'``) blank (str, optional): Blank token. (default: ``'-'``)
Returns: Returns:
Tuple[str]: Tuple[str, ...]:
For models fine-tuned on ASR, returns the tuple of strings representing For models fine-tuned on ASR, returns the tuple of strings representing
the output class labels. the output class labels.
...@@ -199,23 +166,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -199,23 +166,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
return (blank, *self._labels) return (blank, *self._labels)
def _get_state_dict(self, dl_kwargs): def _get_state_dict(self, dl_kwargs):
state_dict = super()._get_state_dict(dl_kwargs) return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
if self._remove_aux_axis:
# Remove the seemingly unnecessary axis
# For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
# It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
# but not used during the ASR training.
# https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
# https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
#
# Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
# that resembles mistake.
# The label `1` shows up in the training dataset of German (1 out of 16M),
# English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
for key in ["aux.weight", "aux.bias"]:
t = state_dict[key]
state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
return state_dict
WAV2VEC2_BASE = Wav2Vec2Bundle( WAV2VEC2_BASE = Wav2Vec2Bundle(
......
from typing import List, Optional, Tuple
import torch
from torch import nn, Tensor
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
def _get_model(type_, params):
    """Instantiate a model of the requested architecture family.

    Args:
        type_ (str): Architecture family, one of ``"Wav2Vec2"`` or ``"WavLM"``.
        params (dict): Keyword arguments forwarded to the model factory.

    Returns:
        The model built by the corresponding factory function.

    Raises:
        ValueError: If ``type_`` is not a supported model type.
    """
    factory_by_type = {
        "Wav2Vec2": wav2vec2_model,
        "WavLM": wavlm_model,
    }
    factory = factory_by_type.get(type_)
    if factory is None:
        raise ValueError(f"Supported model types are {tuple(factory_by_type.keys())}. Found: {type_}")
    return factory(**params)
class _Wav2Vec2Model(nn.Module):
    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.

    This is used for layer normalization at the input: every call first
    applies ``layer_norm`` over the whole input tensor, then delegates to
    the wrapped model. The wrapper is TorchScript-compatible
    (``extract_features`` is explicitly exported).
    """

    def __init__(self, model: Wav2Vec2Model):
        super().__init__()
        # The wrapped model; all inference is delegated to it after the
        # input waveforms have been layer-normalized.
        self.model = model

    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        # Normalizing with ``normalized_shape=waveforms.shape`` computes the
        # statistics over the entire tensor (all dimensions at once).
        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        return self.model(waveforms, lengths)

    @torch.jit.export
    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        # Same input normalization as ``forward``; ``@torch.jit.export``
        # keeps this method callable after TorchScript compilation.
        waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        return self.model.extract_features(waveforms, lengths, num_layers)
def _apply_input_layer_norm(module):
    """Add extra layer_norm to the model.

    Wraps *module* in :class:`_Wav2Vec2Model`, which layer-normalizes the
    input waveforms before delegating every call to *module*.
    """
    wrapped = _Wav2Vec2Model(module)
    return wrapped
def _remove_aux_axes(state_dict, axes):
# Remove the seemingly unnecessary axis
# For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
# It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
# but not used during the ASR training.
# https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
# https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
#
# Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
# that resembles mistake.
# The label `1` shows up in the training dataset of German (1 out of 16M),
# English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
for key in ["aux.weight", "aux.bias"]:
mat = state_dict[key]
state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
def _get_state_dict(url, dl_kwargs, remove_axes=None):
    """Fetch a pretrained state dict, optionally pruning auxiliary rows.

    Args:
        url (str): Either a full ``https`` URL, or a file name resolved
            against ``https://download.pytorch.org/torchaudio/models/``.
        dl_kwargs (dict or None): Keyword arguments forwarded to
            ``load_state_dict_from_url``; ``None`` means no extra arguments.
        remove_axes (collection of int, optional): When given and non-empty,
            row indices removed from ``aux.weight`` / ``aux.bias``.

    Returns:
        dict: The downloaded (and possibly pruned) state dict.
    """
    full_url = url if url.startswith("https") else f"https://download.pytorch.org/torchaudio/models/{url}"
    kwargs = {} if dl_kwargs is None else dl_kwargs
    state_dict = load_state_dict_from_url(full_url, **kwargs)
    if remove_axes:
        _remove_aux_axes(state_dict, remove_axes)
    return state_dict
def _get_en_labels(): def _get_en_labels():
return ( return (
"|", "|",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment