"docs/vscode:/vscode.git/clone" did not exist on "fbd9770a9b7c5bc2eb2dbf66b43e1c54ef165185"
Unverified commit 358e9e93 authored by moto, committed by GitHub

Add HUBERT_BASE and HUBERT_ASR_LARGE pretrained models (#1821)

parent 8c262c14
#!/usr/bin/env python3
"""Convert a Wav2Vec2/HuBERT model published by fairseq into torchaudio format

Examples

```
python convert_fairseq_models.py \
    --input-file hubert_base_ls960.pt \
    --output-file hubert_fairseq_base_ls960.pth

python convert_fairseq_models.py \
    --input-file hubert_large_ll60k.pt \
    --output-file hubert_fairseq_large_ll60k.pth

python convert_fairseq_models.py \
    --input-file hubert_large_ll60k_finetune_ls960.pt \
    --output-file hubert_fairseq_large_ll60k_asr_ls960.pth

python convert_fairseq_models.py \
    --input-file hubert_xtralarge_ll60k.pt \
    --output-file hubert_fairseq_xlarge_ll60k.pth

python convert_fairseq_models.py \
    --input-file hubert_xtralarge_ll60k_finetune_ls960.pt \
    --output-file hubert_fairseq_xlarge_ll60k_asr_ls960.pth
```
"""
import argparse

# Note: torch and fairseq are imported lazily inside the functions below,
# because importing them at module scope is slow.


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        '--input-file', required=True,
        help='Input model file.'
    )
    parser.add_argument(
        '--output-file', required=True,
        help='Output model file.'
    )
    parser.add_argument(
        '--dict-dir',
        help=(
            'Directory where the letter vocabulary file, `dict.ltr.txt`, is found. '
            'Required when loading a fine-tuned (ASR) model. '
            'https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt'
        )
    )
    return parser.parse_args()


def _load_model(input_file, dict_dir):
    import fairseq

    overrides = {} if dict_dir is None else {'data': dict_dir}
    models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [input_file], arg_overrides=overrides,
    )
    return models[0]


def _import_model(model):
    from torchaudio.models.wav2vec2.utils import import_fairseq_model

    # Fine-tuned checkpoints wrap the encoder in a CTC model; unwrap it first.
    if model.__class__.__name__ in ['HubertCtc', 'Wav2VecCtc']:
        model = model.w2v_encoder
    model = import_fairseq_model(model)
    return model


def _main(args):
    import torch

    model = _load_model(args.input_file, args.dict_dir)
    model = _import_model(model)
    torch.save(model.state_dict(), args.output_file)


if __name__ == '__main__':
    _main(_parse_args())
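
The file this script writes is a plain state dict, so it can be loaded back into the matching torchaudio model. A minimal sketch of that round trip (hypothetical usage, not part of the commit; it assumes the output of the first example above and that the `torchaudio.models.hubert_base` factory builds the matching "Base" architecture):

import torch
import torchaudio

# Build an uninitialized HuBERT "Base" model and load the converted weights.
model = torchaudio.models.hubert_base()
state_dict = torch.load('hubert_fairseq_base_ls960.pth')
model.load_state_dict(state_dict)
model.eval()
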
@@ -130,11 +130,26 @@ hubert_ft_xlarge
.. autofunction:: hubert_ft_xlarge

.. currentmodule:: torchaudio.models.wav2vec2.utils

Pre-trained Models
------------------

.. autoclass:: Wav2Vec2PretrainedModelBundle

   .. automethod:: get_model

   .. autoproperty:: labels

.. autodata:: HUBERT_BASE
   :no-value:

.. autodata:: HUBERT_ASR_LARGE
   :no-value:

Utility Functions
-----------------

.. currentmodule:: torchaudio.models.wav2vec2.utils

import_huggingface_model
^^^^^^^^^^^^^^^^^^^^^^^^
......
@misc{conneau2020unsupervised,
    title={Unsupervised Cross-lingual Representation Learning for Speech Recognition},
    author={Alexis Conneau and Alexei Baevski and Ronan Collobert and Abdelrahman Mohamed and Michael Auli},
    year={2020},
    eprint={2006.13979},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

@inproceedings{Gales2014SpeechRA,
    title={Speech recognition and keyword spotting for low-resource languages: Babel project research at CUED},
    author={Mark John Francis Gales and Kate Knill and Anton Ragni and Shakti Prasad Rath},
    booktitle={SLTU},
    year={2014}
}

@misc{ardila2020common,
    title={Common Voice: A Massively-Multilingual Speech Corpus},
    author={Rosana Ardila and Megan Branson and Kelly Davis and Michael Henretty and Michael Kohler and Josh Meyer and Reuben Morais and Lindsay Saunders and Francis M. Tyers and Gregor Weber},
    year={2020},
    eprint={1912.06670},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

@article{Pratap_2020,
    title={MLS: A Large-Scale Multilingual Dataset for Speech Research},
    url={http://dx.doi.org/10.21437/Interspeech.2020-2826},
    doi={10.21437/interspeech.2020-2826},
    journal={Interspeech 2020},
    publisher={ISCA},
    author={Pratap, Vineel and Xu, Qiantong and Sriram, Anuroop and Synnaeve, Gabriel and Collobert, Ronan},
    year={2020},
    month={Oct}
}

@inproceedings{librilight,
    author={J. {Kahn} and M. {Rivière} and W. {Zheng} and E. {Kharitonov} and Q. {Xu} and P. E. {Mazaré} and J. {Karadayi} and V. {Liptchinsky} and R. {Collobert} and C. {Fuegen} and T. {Likhomanenko} and G. {Synnaeve} and A. {Joulin} and A. {Mohamed} and E. {Dupoux}},
    booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    title={Libri-Light: A Benchmark for ASR with Limited or No Supervision},
    year={2020},
    pages={7669-7673},
    note={\url{https://github.com/facebookresearch/libri-light}}
}

@inproceedings{7178964,
    author={Panayotov, Vassil and Chen, Guoguo and Povey, Daniel and Khudanpur, Sanjeev},
    booktitle={2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    title={Librispeech: An ASR corpus based on public domain audio books},
    year={2015},
    pages={5206-5210},
    doi={10.1109/ICASSP.2015.7178964}
}

@inproceedings{ott2019fairseq,
    title={fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
    author={Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
    booktitle={Proceedings of NAACL-HLT 2019: Demonstrations},
    year={2019}
}

@misc{baevski2020wav2vec,
    title={wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations},
    author={Alexei Baevski and Henry Zhou and Abdelrahman Mohamed and Michael Auli},
......
import torch
import pytest

from torchaudio_unittest.common_utils import get_asset_path


class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels):
        super().__init__()
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence of logits over labels, get the best path string

        Args:
            logits (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq, ]
        best_path = torch.unique_consecutive(best_path, dim=-1)
        hypothesis = []
        for i in best_path:
            char = self.labels[i]
            if char not in ['<s>', '<pad>']:
                hypothesis.append(char)
        return ''.join(hypothesis)


@pytest.fixture
def ctc_decoder():
    return GreedyCTCDecoder


@pytest.fixture
def sample_speech_16000_en():
    return get_asset_path('Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac')
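
To make the decoding rule concrete, here is a small illustrative check (not part of the commit) of how `GreedyCTCDecoder` collapses repeated argmax indices with `torch.unique_consecutive` and then drops the `'<s>'`/`'<pad>'` symbols. The four-symbol label set is made up for the example:

import torch

# Hypothetical 4-class label set; index 1 plays the role of the CTC blank/pad.
labels = ['<s>', '<pad>', 'A', 'B']
decoder = GreedyCTCDecoder(labels)

# Frame-wise logits whose argmax sequence is [2, 2, 1, 3, 3]:
# repeats collapse to [2, 1, 3], '<pad>' is dropped, leaving 'AB'.
logits = torch.tensor([
    [0., 0., 9., 0.],
    [0., 0., 9., 0.],
    [0., 9., 0., 0.],
    [0., 0., 0., 9.],
    [0., 0., 0., 9.],
])
assert decoder(logits) == 'AB'
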
import torchaudio
from torchaudio.models import (
    HUBERT_BASE,
    HUBERT_ASR_LARGE,
)
import pytest


@pytest.mark.parametrize(
    "bundle",
    [
        HUBERT_BASE,
    ]
)
def test_pretraining_models(bundle):
    """Smoke test of downloading weights for pretraining models"""
    bundle.get_model()


@pytest.mark.parametrize(
    "bundle,expected",
    [
        (HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
    ]
)
def test_finetune_asr_model(
    bundle,
    expected,
    sample_speech_16000_en,
    ctc_decoder,
):
    """Smoke test of downloading weights for fine-tuned models and a simple transcription"""
    model = bundle.get_model().eval()
    waveform, sample_rate = torchaudio.load(sample_speech_16000_en)
    emission, _ = model(waveform)
    decoder = ctc_decoder(bundle.labels)
    result = decoder(emission[0])
    assert result == expected
@@ -17,6 +17,11 @@ from .wav2vec2 import (
    hubert_ft_large,
    hubert_ft_xlarge,
)
from .wav2vec2.pretrained import (
    Wav2Vec2PretrainedModelBundle,
    HUBERT_BASE,
    HUBERT_ASR_LARGE,
)

__all__ = [
    'Wav2Letter',
@@ -36,6 +41,9 @@ __all__ = [
    'hubert_xlarge',
    'hubert_ft_large',
    'hubert_ft_xlarge',
    'Wav2Vec2PretrainedModelBundle',
    'HUBERT_BASE',
    'HUBERT_ASR_LARGE',
    'Tacotron2',
    'tacotron2',
]
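
With these exports in place, the bundles become importable directly from `torchaudio.models`. A quick illustrative check (hypothetical usage, not part of the diff):

from torchaudio.models import HUBERT_BASE, HUBERT_ASR_LARGE

# The pretraining bundle carries no output labels,
# while the fine-tuned ASR bundle exposes its character vocabulary.
assert HUBERT_BASE.labels is None
print(HUBERT_ASR_LARGE.labels[:5])  # ('<s>', '<pad>', '</s>', '<unk>', '|')
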
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from torch.hub import load_state_dict_from_url

from .model import _get_model, Wav2Vec2Model

__all__ = []


@dataclass
class Wav2Vec2PretrainedModelBundle:
    """torchaudio.models.Wav2Vec2PretrainedModelBundle()

    Data class that bundles the information needed to use a pretrained
    :class:`Wav2Vec2Model`.

    This class provides interfaces for instantiating the pretrained model along with
    the information necessary to retrieve the pretrained weights and additional data
    to be used with the model.

    The torchaudio library instantiates objects of this class, each of which represents
    a different pretrained model. Client code should access pretrained models via these
    instances.

    Please see below for usage and the available values.

    Example - Pretraining model
        >>> import torchaudio
        >>>
        >>> # Build the model and load the pretrained weights.
        >>> model = torchaudio.models.HUBERT_BASE.get_model()
        Downloading:
        100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
        >>> # Extract acoustic features
        >>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
        >>> features, _ = model.extract_features(waveform)

    Example - Model fine-tuned for ASR
        >>> import torchaudio
        >>>
        >>> # Build the model and load the pretrained weights.
        >>> model = torchaudio.models.HUBERT_ASR_LARGE.get_model()
        Downloading:
        100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
        >>> # Check the labels corresponding to the model output.
        >>> labels = torchaudio.models.HUBERT_ASR_LARGE.labels
        >>> print(labels)
        ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
        >>> # Infer the label probability distribution
        >>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
        >>> emissions, _ = model(waveform)
        >>> # Pass the emission to a decoder
        >>> # (`ctc_decode` here is for illustration purposes only)
        >>> transcripts = ctc_decode(emissions, labels)
    """  # noqa: E501
    _path: str
    _params: Dict[str, Any]
    _labels: Optional[Tuple[str, ...]]

    def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
        """Construct the model and load the pretrained weights.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`.

        Args:
            dl_kwargs (dict of keyword arguments):
                Passed to :func:`torch.hub.load_state_dict_from_url`.
        """
        model = _get_model(**self._params)
        url = f'https://download.pytorch.org/models/audio/{self._path}'
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        model.load_state_dict(state_dict)
        return model

    @property
    def labels(self) -> Optional[Tuple[str, ...]]:
        """The optional output class labels (only applicable to fine-tuned ASR bundles)

        Returns:
            Tuple of strings or None:
            For fine-tuned ASR models, the tuple of strings representing
            the output class labels. For non-ASR models, ``None``.
        """
        return self._labels


def _get_labels():
    return (
        '<s>',
        '<pad>',
        '</s>',
        '<unk>',
        '|',
        'E',
        'T',
        'A',
        'O',
        'N',
        'I',
        'H',
        'S',
        'R',
        'D',
        'L',
        'U',
        'M',
        'W',
        'C',
        'F',
        'G',
        'Y',
        'P',
        'B',
        'V',
        'K',
        "'",
        'X',
        'J',
        'Q',
        'Z',
    )


HUBERT_BASE = Wav2Vec2PretrainedModelBundle(
    'hubert_fairseq_base_ls960.pth',
    {
        'extractor_mode': 'group_norm',
        'extractor_conv_layer_config': [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        'extractor_conv_bias': False,
        'encoder_embed_dim': 768,
        'encoder_projection_dropout': 0.1,
        'encoder_pos_conv_kernel': 128,
        'encoder_pos_conv_groups': 16,
        'encoder_num_layers': 12,
        'encoder_num_heads': 12,
        'encoder_attention_dropout': 0.1,
        'encoder_ff_interm_features': 3072,
        'encoder_ff_interm_dropout': 0.0,
        'encoder_dropout': 0.1,
        'encoder_layer_norm_first': False,
        'encoder_layer_drop': 0.05,
        'aux_num_out': None,
    },
    _labels=None,
)
HUBERT_BASE.__doc__ = """HuBERT model with the "Base" configuration.

Pre-trained on 960 hours of the *LibriSpeech* [:footcite:`7178964`] dataset. Not fine-tuned.

Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`].
[`Source <https://github.com/pytorch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models>`__]
"""

HUBERT_ASR_LARGE = Wav2Vec2PretrainedModelBundle(
    'hubert_fairseq_large_ll60k_asr_ls960.pth',
    {
        'extractor_mode': 'layer_norm',
        'extractor_conv_layer_config': [
            (512, 10, 5),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 3, 2),
            (512, 2, 2),
            (512, 2, 2),
        ],
        'extractor_conv_bias': False,
        'encoder_embed_dim': 1024,
        'encoder_projection_dropout': 0.0,
        'encoder_pos_conv_kernel': 128,
        'encoder_pos_conv_groups': 16,
        'encoder_num_layers': 24,
        'encoder_num_heads': 16,
        'encoder_attention_dropout': 0.0,
        'encoder_ff_interm_features': 4096,
        'encoder_ff_interm_dropout': 0.1,
        'encoder_dropout': 0.0,
        'encoder_layer_norm_first': True,
        'encoder_layer_drop': 0.1,
        'aux_num_out': 32,
    },
    _labels=_get_labels(),
)
HUBERT_ASR_LARGE.__doc__ = """HuBERT model with the "Large" configuration.

Pre-trained on 60,000 hours of the *Libri-Light* [:footcite:`librilight`] dataset, and
fine-tuned for ASR on 960 hours of the *LibriSpeech* [:footcite:`7178964`] dataset.

Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`].
[`Source <https://github.com/pytorch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models>`__]
"""