Commit 5e211d66 authored by moto, committed by Facebook GitHub Bot

Add MMS FA Bundle (#3521)

Summary:
Port the MMS FA model from the tutorial to the library, with a post-processing module.
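For context, a minimal sketch of how the new pieces fit together (illustrative only: the audio file, `words`, and variable names are placeholders; the APIs are the ones added in this diff):

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.MMS_FA

model = bundle.get_model()          # wav2vec2 model, with the extra "star" output by default
tokenizer = bundle.get_tokenizer()  # words -> token-ID sequences, via bundle.get_dict()
aligner = bundle.get_aligner()      # emission + token IDs -> per-word TokenSpan lists

waveform, sample_rate = torchaudio.load("audio.wav")  # placeholder input
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    emission, _ = model(waveform)  # (batch, time, num_tokens)

words = ["hello", "world"]  # normalized transcript, one entry per word
token_spans = aligner(emission[0], tokenizer(words))
```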

Pull Request resolved: https://github.com/pytorch/audio/pull/3521

Reviewed By: huangruizhe

Differential Revision: D48038285

Pulled By: mthrok

fbshipit-source-id: 571cf0fceaaab4790983be2719f1a85805b814f5
parent 30668afb
@@ -5,8 +5,13 @@
.. autoclass:: {{ fullname }}()
{%- set support_classes = [] %}
{%- if name in ["RNNTBundle.FeatureExtractor", "RNNTBundle.TokenProcessor", "Wav2Vec2FABundle.Tokenizer"] %}
{%- set methods = ["__call__"] %}
{%- elif name == "Wav2Vec2FABundle.Aligner" %}
{%- set attributes = [] %}
{%- set methods = ["__call__"] %}
{%- set support_classes = ["Token"] %}
{%- elif name == "Tacotron2TTSBundle.TextProcessor" %}
{%- set attributes = ["tokens"] %}
{%- set methods = ["__call__"] %}
@@ -21,12 +26,17 @@
{%- set methods = ["__call__"] %}
{% endif %}
{%- if attributes %}

Properties
----------

{%- endif %}
{%- for item in attributes %}
{%- if not item.startswith('_') %}
{{ item | underline("~") }}
.. container:: py attribute
@@ -35,13 +45,17 @@
{%- endif %}
{%- endfor %}
{%- if methods %}

Methods
-------

{%- endif %}
{%- for item in methods %}
{%- if item != "__init__" %}
{{item | underline("~") }}
.. container:: py attribute
@@ -49,3 +63,24 @@
{%- endif %}
{%- endfor %}
{%- if support_classes %}
Support Structures
------------------
{%- endif %}
{%- for item in support_classes %}
{% set components = item.split('.') %}
{{ components[-1] | underline("~") }}
.. container:: py attribute
.. autoclass:: {{[fullname, item] | join('.')}}
:members:
{%- endfor %}
@@ -142,6 +142,38 @@ Pretrained Models
HUBERT_ASR_LARGE
HUBERT_ASR_XLARGE
wav2vec 2.0 / HuBERT - Forced Alignment
---------------------------------------
Interface
~~~~~~~~~
``Wav2Vec2FABundle`` bundles a pre-trained model and its associated dictionary. Additionally, it supports appending a ``star`` token dimension.
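In outline, the components are obtained from the bundle as follows (a sketch; see :py:class:`~torchaudio.pipelines.Wav2Vec2FABundle` for details)::

    bundle = torchaudio.pipelines.MMS_FA
    model = bundle.get_model()          # estimates frame-wise token probabilities
    tokenizer = bundle.get_tokenizer()  # converts a transcript into token IDs
    aligner = bundle.get_aligner()      # aligns token IDs against the emission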
.. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2fabundle.png
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
Wav2Vec2FABundle
Wav2Vec2FABundle.Tokenizer
Wav2Vec2FABundle.Aligner
.. rubric:: Tutorials using ``Wav2Vec2FABundle``
.. minigallery:: torchaudio.pipelines.Wav2Vec2FABundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
MMS_FA
.. _Tacotron2:
@@ -570,3 +570,12 @@ year = {2017},
URL = {https://arxiv.org/abs/1609.09430},
booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
}
@misc{pratap2023scaling,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
year={2023},
eprint={2305.13516},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@@ -18,6 +18,7 @@ from ._wav2vec2.impl import (
HUBERT_BASE,
HUBERT_LARGE,
HUBERT_XLARGE,
MMS_FA,
VOXPOPULI_ASR_BASE_10K_DE,
VOXPOPULI_ASR_BASE_10K_EN,
VOXPOPULI_ASR_BASE_10K_ES,
@@ -41,6 +42,7 @@ from ._wav2vec2.impl import (
WAV2VEC2_XLSR_300M,
Wav2Vec2ASRBundle,
Wav2Vec2Bundle,
Wav2Vec2FABundle,
WAVLM_BASE,
WAVLM_BASE_PLUS,
WAVLM_LARGE,
@@ -51,6 +53,7 @@ from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"Wav2Vec2Bundle",
"Wav2Vec2ASRBundle",
"Wav2Vec2FABundle",
"WAV2VEC2_BASE", "WAV2VEC2_BASE",
"WAV2VEC2_LARGE", "WAV2VEC2_LARGE",
"WAV2VEC2_LARGE_LV60K", "WAV2VEC2_LARGE_LV60K",
@@ -77,6 +80,7 @@ __all__ = [
"HUBERT_XLARGE",
"HUBERT_ASR_LARGE",
"HUBERT_ASR_XLARGE",
"MMS_FA",
"WAVLM_BASE", "WAVLM_BASE",
"WAVLM_BASE_PLUS", "WAVLM_BASE_PLUS",
"WAVLM_LARGE", "WAVLM_LARGE",
from abc import ABC, abstractmethod
from typing import Dict, List
import torch
import torchaudio.functional as F
from torch import Tensor
from torchaudio.functional import TokenSpan
class ITokenizer(ABC):
@abstractmethod
def __call__(self, transcript: List[str]) -> List[List[int]]:
"""Tokenize the given transcript (list of words)
.. note::
The transcript must be normalized.
Args:
transcript (list of str): Transcript (list of words).
Returns:
(list of list of int): Token-ID sequences, one per word.
"""
class Tokenizer(ITokenizer):
def __init__(self, dictionary: Dict[str, int]):
self.dictionary = dictionary
def __call__(self, transcript: List[str]) -> List[List[int]]:
return [[self.dictionary[c] for c in word] for word in transcript]
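# e.g., with a hypothetical dictionary: Tokenizer({"a": 1, "b": 2})(["ab", "ba"]) -> [[1, 2], [2, 1]]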
def _align_emission_and_tokens(emission: Tensor, tokens: List[int]):
device = emission.device
emission = emission.unsqueeze(0)
targets = torch.tensor([tokens], dtype=torch.int32, device=device)
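# the blank token is at index 0 of the emission / dictionary (cf. Wav2Vec2FABundle.get_dict)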
aligned_tokens, scores = F.forced_align(emission, targets, blank=0)
scores = scores.exp() # convert back to probability
aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
return aligned_tokens, scores
class IAligner(ABC):
@abstractmethod
def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
"""Generate list of time-stamped token sequences
Args:
emission (Tensor): Sequence of token probability distributions.
Shape: `(time, tokens)`.
tokens (list of integer sequence): Tokenized transcript.
Output from :py:class:`Wav2Vec2FABundle.Tokenizer`.
Returns:
(list of TokenSpan sequence): Tokens with time stamps and scores.
"""
def _unflatten(list_, lengths):
assert len(list_) == sum(lengths)
i = 0
ret = []
for l in lengths:
ret.append(list_[i : i + l])
i += l
return ret
def _flatten(nested_list):
return [item for list_ in nested_list for item in list_]
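# e.g. _flatten([[1, 2], [3]]) -> [1, 2, 3]; _unflatten([1, 2, 3], [2, 1]) -> [[1, 2], [3]]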
class Aligner(IAligner):
def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
if emission.ndim != 2:
raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
emission = torch.log_softmax(emission, dim=-1)
aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
spans = F.merge_tokens(aligned_tokens, scores)
return _unflatten(spans, [len(ts) for ts in tokens])
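# e.g. for a (time, num_tokens) emission and tokens == [[1, 2], [2, 1]],
# this returns two lists of TokenSpan, one per word of the transcript.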
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from torch.nn import Module
from . import aligner, utils
__all__ = [] # type: ignore
@@ -146,7 +147,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
*,
blank: str = "-",
) -> Tuple[str, ...]:
"""The output class labels.
The first is the blank token, and it is customizable.
@@ -159,8 +160,8 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
the output class labels.
Example
>>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle
>>> bundle.get_labels()
('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
""" # noqa: E501
return (blank, *self._labels)
@@ -1518,3 +1519,181 @@ redistributed with the same license.
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
""" # noqa: E501
@dataclass
class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
"""Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model` for forced alignment.
This class provides interfaces for instantiating the pretrained model along with
the information necessary to retrieve pretrained weights and additional data
to be used with the model.
Torchaudio library instantiates objects of this class, each of which represents
a different pretrained model. Client code should access pretrained models via these
instances.
Please see below for the usage and the available values.
Example - Feature Extraction
>>> import torchaudio
>>>
>>> bundle = torchaudio.pipelines.MMS_FA
>>>
>>> # Build the model and load pretrained weight.
>>> model = bundle.get_model()
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s]
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Estimate the probability of token distribution
>>> emission, _ = model(waveform)
>>>
>>> # Generate frame-wise alignment
>>> alignment, scores = torchaudio.functional.forced_align(
>>> emission, targets, input_lengths, target_lengths, blank=0)
>>>
""" # noqa: E501
class Tokenizer(aligner.ITokenizer):
"""Interface of the tokenizer"""
class Aligner(aligner.IAligner):
"""Interface of the aligner"""
def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str, ...]:
"""Get the labels corresponding to the feature dimension of emission.
The first is the blank token, and it is customizable.
Args:
star (str or None, optional): Change or disable star token. (default: ``"*"``)
blank (str, optional): Change the blank token. (default: ``'-'``)
Returns:
Tuple[str, ...]:
The tuple of strings representing the output class labels.
Example
>>> from torchaudio.pipelines import MMS_FA as bundle
>>> bundle.get_labels()
('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*')
>>> bundle.get_labels(star=None)
('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
""" # noqa: E501
labels = super().get_labels(blank=blank)
return labels if star is None else (*labels, star)
def _get_params_with_star(self):
params = copy.deepcopy(self._params)
params["aux_num_out"] += 1
return params
def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
"""Construct the model and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
with_star (bool, optional): If enabled, the last dimension of the output layer is
extended by one, which corresponds to the `star` token.
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns:
Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
"""
params = self._get_params_with_star() if with_star else self._params
model = utils._get_model(self._model_type, params)
state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star)
model.load_state_dict(state_dict)
if self._normalize_waveform:
model = utils._apply_input_layer_norm(model)
model.eval()
return model
def get_dict(self, star: Optional[str] = "*", blank: str = "-") -> Dict[str, int]:
"""Get the mapping from token to index (in emission feature dim)
Args:
star (str or None, optional): Change or disable star token. (default: ``"*"``)
blank (str, optional): Change the blank token. (default: ``'-'``)
Returns:
Dict[str, int]:
Mapping from each token to its index in the emission feature dimension.
Example
>>> from torchaudio.pipelines import MMS_FA as bundle
>>> bundle.get_dict()
{'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28}
>>> bundle.get_dict(star=None)
{'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
""" # noqa: E501
return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}
def get_tokenizer(self) -> Tokenizer:
"""Instantiate a Tokenizer.
Returns:
Tokenizer
"""
return aligner.Tokenizer(self.get_dict())
def get_aligner(self) -> Aligner:
"""Instantiate an Aligner.
Returns:
Aligner
"""
return aligner.Aligner()
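# Typical composition (sketch): spans = bundle.get_aligner()(emission, bundle.get_tokenizer()(words))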
MMS_FA = Wav2Vec2FABundle(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
{
"extractor_mode": "layer_norm",
"extractor_conv_layer_config": [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
"extractor_conv_bias": True,
"encoder_embed_dim": 1024,
"encoder_projection_dropout": 0.0,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 24,
"encoder_num_heads": 16,
"encoder_attention_dropout": 0.0,
"encoder_ff_interm_features": 4096,
"encoder_ff_interm_dropout": 0.1,
"encoder_dropout": 0.0,
"encoder_layer_norm_first": True,
"encoder_layer_drop": 0.1,
"aux_num_out": 28,
},
_labels=utils._get_mms_labels(),
_sample_rate=16000,
_normalize_waveform=True,
_model_type="Wav2Vec2",
)
MMS_FA.__doc__ = """
Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
Published by the paper's authors under [`CC-BY-NC 4.0 License <https://github.com/facebookresearch/fairseq/tree/100cd91db19bb27277a06a25eb4154c805b10189/examples/mms#license>`__].
Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
.. note::
Unlike other Wav2Vec2 bundles, this model does not have a word-boundary token (such as ``|``). This makes the post-processing of alignments slightly different; a sketch follows below.
""" # noqa: E501
@@ -65,13 +65,23 @@ def _remove_aux_axes(state_dict, axes):
state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
def _add_star_dim(state_dict):
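# Appends the "star" output class after the existing ones: zero weights and a
# bias of one give the star token a constant logit.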
w, b = state_dict["aux.weight"], state_dict["aux.bias"]
zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype)
state_dict["aux.weight"] = torch.cat((zeros, w), dim=0)
ones = torch.ones((1,), device=b.device, dtype=b.dtype)
state_dict["aux.bias"] = torch.cat((b, ones), dim=0)
def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False):
if not url.startswith("https"):
url = f"https://download.pytorch.org/torchaudio/models/{url}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
if remove_axes:
_remove_aux_axes(state_dict, remove_axes)
if add_star:
_add_star_dim(state_dict)
return state_dict
@@ -301,3 +311,35 @@ def _get_it_labels():
"í",
"ï",
)
def _get_mms_labels():
return (
"a",
"i",
"e",
"n",
"o",
"u",
"t",
"s",
"r",
"m",
"k",
"l",
"d",
"g",
"h",
"y",
"b",
"p",
"w",
"c",
"v",
"j",
"z",
"f",
"'",
"q",
"x",
)