Unverified Commit f1a0b605 authored by moto, committed by GitHub

Add wav2vec2 fairseq importer (#1531)

parent 07d9bc21
@@ -58,3 +58,8 @@ fi
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
cd fairseq
git checkout e6eddd80
pip install .
@@ -46,3 +46,8 @@ fi
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
cd fairseq
git checkout e6eddd80
pip install .
@@ -62,6 +62,8 @@ Utility Functions
.. autofunction:: import_huggingface_model
.. autofunction:: import_fairseq_model
.. currentmodule:: torchaudio.models
:hidden:`WaveRNN`
#!/usr/bin/env python3
"""Generate the conf JSON from fairseq pretrained weight file, that is consumed by unit tests
Usage:
1. Download pretrained parameters from https://github.com/pytorch/fairseq/tree/master/examples/wav2vec
2. Download the dict from https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt
and put it in the same directory as parameter files.
3. Run this script and save the resulting JSON configuration in the assets directory.
Example:
```
# Pretrained
python generate_fairseq_model_config.py \
--model-file wav2vec_small.pt \
> wav2vec_small.json
python generate_fairseq_model_config.py \
--model-file libri960_big.pt \
> libri960_big.json
python generate_fairseq_model_config.py \
--model-file wav2vec_vox_new.pt \
> wav2vec_vox_new.json
# Fine-tuned
python generate_fairseq_model_config.py \
--model-file wav2vec_small_960h.pt \
> wav2vec_small_960h.json
python generate_fairseq_model_config.py \
--model-file wav2vec_big_960h.pt \
> wav2vec_large_960h.json
python generate_fairseq_model_config.py \
--model-file wav2vec2_vox_960h_new.pt \
> wav2vec_large_lv60_960h.json
python generate_fairseq_model_config.py \
--model-file wav2vec_vox_960h_pl.pt \
> wav2vec_large_lv60_self_960h.json
```
"""
import os
import json
import argparse
def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'--model-file',
required=True,
help=(
'A checkpoint file from '
'https://github.com/pytorch/fairseq/tree/master/examples/wav2vec'
)
)
parser.add_argument(
'--dict-dir',
help=(
'Directory where `dict.ltr.txt` file is found. '
'Default: the directory of the given model.'
)
)
args = parser.parse_args()
if args.dict_dir is None:
args.dict_dir = os.path.dirname(args.model_file)
return args
def _to_json(conf):
import yaml
from omegaconf import OmegaConf
return yaml.safe_load(OmegaConf.to_yaml(conf))
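# Note: _to_json round-trips the OmegaConf container through YAML to obtain plain
# Python types that json.dumps can serialize. Illustrative sketch (not part of this
# change):
#   >>> from omegaconf import OmegaConf
#   >>> _to_json(OmegaConf.create({'encoder_layers': 12}))
#   {'encoder_layers': 12}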
def _load(model_file, dict_dir):
import fairseq
overrides = {'data': dict_dir}
_, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[model_file], arg_overrides=overrides
)
return _to_json(args['model'])
def _main():
args = _parse_args()
conf = _load(args.model_file, args.dict_dir)
if conf['_name'] == 'wav2vec_ctc':
del conf['data']
del conf['w2v_args']['task']['data']
conf['w2v_args'] = {
key: conf['w2v_args'][key] for key in ['model', 'task']
}
print(json.dumps(conf, indent=4, sort_keys=True))
if __name__ == '__main__':
_main()
{
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": false,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.2,
"encoder_layers": 24,
"extractor_mode": "default",
"feature_grad_mult": 0.1,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.5,
0.999995
],
"latent_vars": 320,
"layer_norm_first": false,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
}
{
"_name": "wav2vec_ctc",
"activation_dropout": 0.1,
"apply_mask": true,
"attention_dropout": 0.0,
"blank_mode": "add",
"blank_weight": 0.0,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"dropout": 0.0,
"dropout_input": 0.0,
"encoder_embed_dim": 512,
"feature_grad_mult": 0.0,
"final_dropout": 0.0,
"freeze_finetune_updates": 10000,
"layerdrop": 0.2,
"mask_channel_before": false,
"mask_channel_length": 64,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.1,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.5,
"mask_selection": "static",
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"no_pretrained_weights": false,
"normalize": false,
"w2v_args": {
"model": {
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": false,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.2,
"encoder_layers": 24,
"extractor_mode": "default",
"feature_grad_mult": 0.1,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.5,
0.999995
],
"latent_vars": 320,
"layer_norm_first": false,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
},
"task": {
"_name": "audio_pretraining",
"autoregressive": false,
"binarized_dataset": false,
"enable_padding": false,
"eval_wer": false,
"eval_wer_config": {
"beam": 5,
"constraints": null,
"decoding_format": null,
"diverse_beam_groups": -1,
"diverse_beam_strength": 0.5,
"diversity_rate": -1.0,
"iter_decode_eos_penalty": 0.0,
"iter_decode_force_max_iter": false,
"iter_decode_max_iter": 10,
"iter_decode_with_beam": 1,
"iter_decode_with_external_reranker": false,
"lenpen": 1.0,
"lm_path": null,
"lm_weight": 0.0,
"match_source_len": false,
"max_len_a": 0.0,
"max_len_b": 200,
"min_len": 1,
"nbest": 1,
"no_beamable_mm": false,
"no_early_stop": false,
"no_repeat_ngram_size": 0,
"no_seed_provided": false,
"prefix_size": 0,
"print_alignment": null,
"print_step": false,
"replace_unk": null,
"retain_dropout": false,
"retain_dropout_modules": null,
"retain_iter_history": false,
"sacrebleu": false,
"sampling": false,
"sampling_topk": -1,
"sampling_topp": -1.0,
"score_reference": false,
"temperature": 1.0,
"unkpen": 0.0,
"unnormalized": false
},
"eval_wer_post_process": "letter",
"eval_wer_tokenizer": null,
"inferred_w2v_config": null,
"labels": null,
"max_sample_size": 320000,
"min_sample_size": 32000,
"normalize": false,
"num_batch_buckets": 0,
"precompute_mask_indices": false,
"sample_rate": 16000,
"tpu": true
}
},
"w2v_path": "???"
}
{
"_name": "wav2vec_ctc",
"activation_dropout": 0.1,
"apply_mask": true,
"attention_dropout": 0.0,
"blank_mode": "add",
"blank_weight": 0.0,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"dropout": 0.0,
"dropout_input": 0.0,
"encoder_embed_dim": 512,
"feature_grad_mult": 0.0,
"final_dropout": 0.0,
"freeze_finetune_updates": 10000,
"layerdrop": 0.1,
"mask_channel_before": false,
"mask_channel_length": 64,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.25,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.5,
"mask_selection": "static",
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"no_pretrained_weights": false,
"normalize": true,
"w2v_args": {
"model": {
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 24,
"extractor_mode": "layer_norm",
"feature_grad_mult": 1.0,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.1,
0.999995
],
"latent_vars": 320,
"layer_norm_first": true,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
},
"task": {
"_name": "audio_pretraining",
"autoregressive": false,
"binarized_dataset": false,
"enable_padding": false,
"eval_wer": false,
"eval_wer_config": {
"beam": 5,
"constraints": null,
"decoding_format": null,
"diverse_beam_groups": -1,
"diverse_beam_strength": 0.5,
"diversity_rate": -1.0,
"iter_decode_eos_penalty": 0.0,
"iter_decode_force_max_iter": false,
"iter_decode_max_iter": 10,
"iter_decode_with_beam": 1,
"iter_decode_with_external_reranker": false,
"lenpen": 1.0,
"lm_path": null,
"lm_weight": 0.0,
"match_source_len": false,
"max_len_a": 0.0,
"max_len_b": 200,
"min_len": 1,
"nbest": 1,
"no_beamable_mm": false,
"no_early_stop": false,
"no_repeat_ngram_size": 0,
"no_seed_provided": false,
"prefix_size": 0,
"print_alignment": null,
"print_step": false,
"replace_unk": null,
"retain_dropout": false,
"retain_dropout_modules": null,
"retain_iter_history": false,
"sacrebleu": false,
"sampling": false,
"sampling_topk": -1,
"sampling_topp": -1.0,
"score_reference": false,
"temperature": 1.0,
"unkpen": 0.0,
"unnormalized": false
},
"eval_wer_post_process": "letter",
"eval_wer_tokenizer": null,
"inferred_w2v_config": null,
"labels": null,
"max_sample_size": 320000,
"min_sample_size": 32000,
"normalize": true,
"num_batch_buckets": 0,
"precompute_mask_indices": false,
"sample_rate": 16000,
"tpu": true
}
},
"w2v_path": "???"
}
{
"_name": "wav2vec_ctc",
"activation_dropout": 0.1,
"apply_mask": true,
"attention_dropout": 0.0,
"blank_mode": "add",
"blank_weight": 0.0,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"dropout": 0.0,
"dropout_input": 0.0,
"encoder_embed_dim": 768,
"feature_grad_mult": 0.0,
"final_dropout": 0.0,
"freeze_finetune_updates": 10000,
"layerdrop": 0.1,
"mask_channel_before": false,
"mask_channel_length": 64,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.1,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.1,
"mask_selection": "static",
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"no_pretrained_weights": false,
"normalize": true,
"w2v_args": {
"model": {
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 24,
"extractor_mode": "layer_norm",
"feature_grad_mult": 1.0,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.1,
0.999995
],
"latent_vars": 320,
"layer_norm_first": true,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
},
"task": {
"_name": "audio_pretraining",
"autoregressive": false,
"binarized_dataset": false,
"enable_padding": false,
"eval_wer": false,
"eval_wer_config": {
"beam": 5,
"constraints": null,
"decoding_format": null,
"diverse_beam_groups": -1,
"diverse_beam_strength": 0.5,
"diversity_rate": -1.0,
"iter_decode_eos_penalty": 0.0,
"iter_decode_force_max_iter": false,
"iter_decode_max_iter": 10,
"iter_decode_with_beam": 1,
"iter_decode_with_external_reranker": false,
"lenpen": 1.0,
"lm_path": null,
"lm_weight": 0.0,
"match_source_len": false,
"max_len_a": 0.0,
"max_len_b": 200,
"min_len": 1,
"nbest": 1,
"no_beamable_mm": false,
"no_early_stop": false,
"no_repeat_ngram_size": 0,
"no_seed_provided": false,
"prefix_size": 0,
"print_alignment": null,
"print_step": false,
"replace_unk": null,
"retain_dropout": false,
"retain_dropout_modules": null,
"retain_iter_history": false,
"sacrebleu": false,
"sampling": false,
"sampling_topk": -1,
"sampling_topp": -1.0,
"score_reference": false,
"temperature": 1.0,
"unkpen": 0.0,
"unnormalized": false
},
"eval_wer_post_process": "letter",
"eval_wer_tokenizer": null,
"inferred_w2v_config": null,
"labels": null,
"max_sample_size": 320000,
"min_sample_size": 32000,
"normalize": true,
"num_batch_buckets": 0,
"precompute_mask_indices": false,
"sample_rate": 16000,
"tpu": true
}
},
"w2v_path": "/private/home/abaevski/models/wav2vec2/wav2vec_vox_new.pt"
}
{
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": false,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.1,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 12,
"encoder_embed_dim": 768,
"encoder_ffn_embed_dim": 3072,
"encoder_layerdrop": 0.05,
"encoder_layers": 12,
"extractor_mode": "default",
"feature_grad_mult": 0.1,
"final_dim": 256,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.5,
0.999995
],
"latent_vars": 320,
"layer_norm_first": false,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
}
{
"_name": "wav2vec_ctc",
"activation_dropout": 0.1,
"apply_mask": true,
"attention_dropout": 0.0,
"blank_mode": "add",
"blank_weight": 0.0,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
"dropout": 0.0,
"dropout_input": 0.0,
"encoder_embed_dim": 512,
"feature_grad_mult": 0.0,
"final_dropout": 0.0,
"freeze_finetune_updates": 0,
"layerdrop": 0.1,
"mask_channel_before": false,
"mask_channel_length": 64,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.1,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.5,
"mask_selection": "static",
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"no_pretrained_weights": false,
"normalize": false,
"w2v_args": {
"model": {
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": false,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.1,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 12,
"encoder_embed_dim": 768,
"encoder_ffn_embed_dim": 3072,
"encoder_layerdrop": 0.05,
"encoder_layers": 12,
"extractor_mode": "default",
"feature_grad_mult": 0.1,
"final_dim": 256,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2,
0.5,
0.999995
],
"latent_vars": 320,
"layer_norm_first": false,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
},
"task": {
"_name": "audio_pretraining",
"autoregressive": false,
"binarized_dataset": false,
"enable_padding": false,
"eval_wer": false,
"eval_wer_config": {
"beam": 5,
"constraints": null,
"decoding_format": null,
"diverse_beam_groups": -1,
"diverse_beam_strength": 0.5,
"diversity_rate": -1.0,
"iter_decode_eos_penalty": 0.0,
"iter_decode_force_max_iter": false,
"iter_decode_max_iter": 10,
"iter_decode_with_beam": 1,
"iter_decode_with_external_reranker": false,
"lenpen": 1.0,
"lm_path": null,
"lm_weight": 0.0,
"match_source_len": false,
"max_len_a": 0.0,
"max_len_b": 200,
"min_len": 1,
"nbest": 1,
"no_beamable_mm": false,
"no_early_stop": false,
"no_repeat_ngram_size": 0,
"no_seed_provided": false,
"prefix_size": 0,
"print_alignment": null,
"print_step": false,
"replace_unk": null,
"retain_dropout": false,
"retain_dropout_modules": null,
"retain_iter_history": false,
"sacrebleu": false,
"sampling": false,
"sampling_topk": -1,
"sampling_topp": -1.0,
"score_reference": false,
"temperature": 1.0,
"unkpen": 0.0,
"unnormalized": false
},
"eval_wer_post_process": "letter",
"eval_wer_tokenizer": null,
"inferred_w2v_config": null,
"labels": null,
"max_sample_size": 250000,
"min_sample_size": 32000,
"normalize": false,
"num_batch_buckets": 0,
"precompute_mask_indices": false,
"sample_rate": 16000,
"tpu": true
}
},
"w2v_path": "???"
}
{
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.1,
"codebook_negatives": 0,
"conv_bias": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.1,
"dropout_input": 0.1,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 24,
"extractor_mode": "layer_norm",
"feature_grad_mult": 1.0,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.1,
0.999995
],
"latent_vars": 320,
"layer_norm_first": true,
"logit_temp": 0.1,
"mask_channel_before": false,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"quantizer_depth": 1,
"quantizer_factor": 3,
"same_quantizer": false,
"target_glu": false
}
{
"_name": "wav2vec2",
"activation_dropout": 0.0,
"activation_fn": "gelu",
"attention_dropout": 0.0,
"codebook_negatives": 0,
"conv_bias": true,
"conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
"conv_pos": 128,
"conv_pos_groups": 16,
"cross_sample_negatives": 0,
"dropout": 0.0,
"dropout_features": 0.0,
"dropout_input": 0.0,
"encoder_attention_heads": 16,
"encoder_embed_dim": 1024,
"encoder_ffn_embed_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 24,
"extractor_mode": "layer_norm",
"feature_grad_mult": 1.0,
"final_dim": 768,
"latent_dim": 0,
"latent_groups": 2,
"latent_temp": [
2.0,
0.1,
0.999995
],
"latent_vars": 320,
"layer_norm_first": true,
"logit_temp": 0.1,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_length": 10,
"mask_min_space": 1,
"mask_other": 0.0,
"mask_prob": 0.65,
"mask_selection": "static",
"negatives_from_everywhere": false,
"no_mask_channel_overlap": false,
"no_mask_overlap": false,
"num_negatives": 100,
"quantize_input": false,
"quantize_targets": true,
"same_quantizer": false,
"target_glu": false
}
import json
import torch
from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
)
from torchaudio.models.wav2vec2.utils import (
import_fairseq_model,
)
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfNoModule,
TorchaudioTestCase,
)
def _load_config(*paths):
with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', 'r') as file_:
return json.load(file_)
# Pretrained (not fine-tuned) models
BASE = _load_config('wav2vec_small')
LARGE = _load_config('libri960_big')
LARGE_LV60K = _load_config('wav2vec_vox_new')
XLSR_53_56K = _load_config('xlsr_53_56k')
# Fine-tuned models
BASE_960H = _load_config('wav2vec_small_960h')
LARGE_960H = _load_config('wav2vec_large_960h')
LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')
# Config and corresponding factory functions
PRETRAINED_CONFIGS = [
(BASE, wav2vec2_base),
(LARGE, wav2vec2_large),
(LARGE_LV60K, wav2vec2_large_lv60k),
(XLSR_53_56K, wav2vec2_large_lv60k),
]
FINETUNED_CONFIGS = [
(BASE_960H, wav2vec2_base),
(LARGE_960H, wav2vec2_large),
(LARGE_LV60K_960H, wav2vec2_large_lv60k),
(LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
]
@skipIfNoModule('fairseq')
class TestFairseqIntegration(TorchaudioTestCase):
"""Test the process of importing the models from fairseq.
Test methods in this test suite check the following things:
1. Models loaded with fairseq can be imported.
2. The same model can be recreated without fairseq.
"""
def _get_model(self, config, num_out):
import copy
from omegaconf import OmegaConf
from fairseq.models.wav2vec.wav2vec2 import (
Wav2Vec2Config,
Wav2Vec2Model,
)
from fairseq.models.wav2vec.wav2vec2_asr import (
Wav2VecEncoder,
Wav2Vec2CtcConfig,
)
if config['_name'] == 'wav2vec_ctc':
config = copy.deepcopy(config)
config['w2v_args'] = OmegaConf.create(config['w2v_args'])
return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out)
if config['_name'] == 'wav2vec2':
return Wav2Vec2Model(Wav2Vec2Config(**config))
@parameterized.expand([conf[:1] for conf in PRETRAINED_CONFIGS])
def test_import_pretrained_model(self, config):
"""Pretrained wav2vec2 models from fairseq can be imported and yields the same results"""
num_out = 28
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original, 28).eval()
x = torch.randn(batch_size, num_frames)
ref = original.feature_extractor(x).transpose(1, 2)
hyp, _ = imported.extract_features(x)
self.assertEqual(ref, hyp)
@parameterized.expand(PRETRAINED_CONFIGS)
def test_recreate_pretrained_model(self, config, factory_func):
"""Imported pretrained models can be recreated via a factory function without fairseq."""
num_out = 28
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original, 28).eval()
reloaded = factory_func(num_out=num_out)
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
x = torch.randn(batch_size, num_frames)
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
# Without mask
ref, _ = imported(x)
hyp, _ = reloaded(x)
self.assertEqual(ref, hyp)
# With mask
ref, ref_lengths = imported(x, lengths)
hyp, hyp_lengths = reloaded(x, lengths)
self.assertEqual(ref, hyp)
self.assertEqual(ref_lengths, hyp_lengths)
@parameterized.expand([conf[:1] for conf in FINETUNED_CONFIGS])
def test_import_finetuned_model(self, config):
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
num_out = 28
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original).eval()
# Without mask
x = torch.randn(batch_size, num_frames)
ref = original(x, torch.zeros_like(x))['encoder_out'].transpose(0, 1)
hyp, _ = imported(x)
self.assertEqual(ref, hyp)
# With mask
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None]
ref = original(x, mask)['encoder_out'].transpose(0, 1)
hyp, output_lengths = imported(x, lengths)
for i, l in enumerate(output_lengths):
self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])
@parameterized.expand(FINETUNED_CONFIGS)
def test_recreate_finetuned_model(self, config, factory_func):
"""Imported finetuned models can be recreated via a factory function without fairseq."""
num_out = 28
batch_size, num_frames = 3, 1024
original = self._get_model(config, num_out).eval()
imported = import_fairseq_model(original).eval()
reloaded = factory_func(num_out=num_out)
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
# Without mask
torch.manual_seed(0)
x = torch.randn(batch_size, num_frames)
ref, _ = imported(x)
hyp, _ = reloaded(x)
self.assertEqual(ref, hyp)
# With mask
lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
ref, ref_lengths = imported(x, lengths)
hyp, hyp_lengths = reloaded(x, lengths)
self.assertEqual(ref, hyp)
self.assertEqual(ref_lengths, hyp_lengths)
from .import_huggingface import import_huggingface_model
from .import_fairseq import import_fairseq_model
__all__ = [
'import_huggingface_model',
'import_fairseq_model',
]
"""Import fariseq's wav2vec2.0 pretrained weights to torchaudios's format.
For this module to work, you need `fairseq`.
"""
import re
from typing import Optional
from torch.nn import Module
from ..model import Wav2Vec2Model, _get_model
def _parse_config(w2v_model, num_out):
encoder = w2v_model.encoder
conv_layers = w2v_model.feature_extractor.conv_layers
if 'GroupNorm' in conv_layers[0][2].__class__.__name__:
extractor_mode = 'group_norm'
else:
extractor_mode = 'layer_norm'
conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers]
if all(l[0].bias is None for l in conv_layers):
conv_bias = False
elif all(l[0].bias is not None for l in conv_layers):
conv_bias = True
else:
raise ValueError(
'Either all the convolution layers should have a bias term, or none of them should.')
config = {
'extractor_mode': extractor_mode,
'extractor_conv_layer_config': conv_layer_config,
'extractor_conv_bias': conv_bias,
'encoder_embed_dim': w2v_model.post_extract_proj.out_features,
'encoder_projection_dropout': w2v_model.dropout_input.p,
'encoder_pos_conv_kernel': encoder.pos_conv[0].kernel_size[0],
'encoder_pos_conv_groups': encoder.pos_conv[0].groups,
'encoder_num_layers': len(encoder.layers),
'encoder_num_heads': encoder.layers[0].self_attn.num_heads,
'encoder_attention_dropout': encoder.layers[0].self_attn.dropout_module.p,
'encoder_ff_interm_features': encoder.layers[0].fc1.out_features,
'encoder_ff_interm_dropout': encoder.layers[0].dropout2.p,
'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop,
'encoder_num_out': num_out,
}
return config
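# Rough sketch of how the parsed config looks (illustrative; the checkpoint name and
# the values below are taken from the base-architecture JSON added in this change):
#
#   import fairseq
#   models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(['wav2vec_small.pt'])
#   config = _parse_config(models[0], num_out=28)
#   # -> e.g. encoder_embed_dim=768, encoder_num_layers=12, encoder_num_heads=12,
#   #    encoder_ff_interm_features=3072, extractor_mode='group_norm'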
def _map_key(key):
key_ = key
if key.startswith('w2v_model.'):
key = key.replace('w2v_model.', '')
if re.match(r'(mask_emb|quantizer|project_q|final_proj)', key):
return None
# Feature Extractor
# Group norm when "extractor_mode" is "default".
# (Only the first layer)
# "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight"
# "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.0\.2\.(weight|bias)', key)
if match:
return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}"
# Convolutions
# "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight"
# "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)', key)
if match:
return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}"
# Layer norm when "extractor_mode" is "layer_norm".
# "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight"
# "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias"
match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)', key)
if match:
return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}"
match = re.match(r"post_extract_proj\.(weight|bias)", key)
# Encoder - Feature projection
if match:
return f"encoder.feature_projection.projection.{match.group(1)}"
match = re.match(r"layer_norm\.(weight|bias)", key)
if match:
return f"encoder.feature_projection.layer_norm.{match.group(1)}"
# Encoder - Transformer - Convolutional positional embedding
match = re.match(r"encoder\.pos_conv\.0\.(bias|weight_g|weight_v)", key)
if match:
return f"encoder.transformer.pos_conv_embed.conv.{match.group(1)}"
match = re.match(r"encoder\.layer_norm\.(weight|bias)", key)
if match:
return f"encoder.transformer.layer_norm.{match.group(1)}"
# Encoder - Transformer - Self attention layers
match = re.match(r"encoder\.layers\.(\d+)\.self_attn\.((k_|v_|q_|out_)proj\.(weight|bias))", key)
if match:
return f"encoder.transformer.layers.{match.group(1)}.attention.{match.group(2)}"
match = re.match(r"encoder\.layers\.(\d+)\.self_attn_layer_norm\.(weight|bias)", key)
if match:
return f"encoder.transformer.layers.{match.group(1)}.layer_norm.{match.group(2)}"
match = re.match(r"encoder\.layers\.(\d+)\.fc1\.(weight|bias)", key)
if match:
return f"encoder.transformer.layers.{match.group(1)}.feed_forward.intermediate_dense.{match.group(2)}"
match = re.match(r"encoder\.layers\.(\d+)\.fc2\.(weight|bias)", key)
if match:
return f"encoder.transformer.layers.{match.group(1)}.feed_forward.output_dense.{match.group(2)}"
match = re.match(r"encoder\.layers\.(\d+)\.final_layer_norm\.(weight|bias)", key)
if match:
return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}"
match = re.match(r"proj\.(weight|bias)", key)
# Encoder - Readout layer
if match:
return f"encoder.readout.{match.group(1)}"
raise ValueError(f'Unexpected key: {key_}')
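# Examples of the mapping above (illustrative):
#   'w2v_model.post_extract_proj.weight' -> 'encoder.feature_projection.projection.weight'
#   'encoder.layers.0.fc1.weight'        -> 'encoder.transformer.layers.0.feed_forward.intermediate_dense.weight'
#   'quantizer.vars'                     -> None (dropped; not used by the imported model)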
def _convert_state_dict(state_dict):
converted = {}
for k, v in state_dict.items():
k = _map_key(k)
if k is not None:
converted[k] = v
return converted
def import_fairseq_model(
original: Module,
num_out: Optional[int] = None) -> Wav2Vec2Model:
"""Build Wav2Vec2Model from pretrained parameters published by `fairseq`_.
Args:
original (torch.nn.Module):
An instance of fairseq's Wav2Vec2.0 model class.
Either ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder`` or
``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
num_out (int, optional):
The number of output labels. Required only when the original model is
an instance of ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
Returns:
Wav2Vec2Model: Imported model.
Example - Loading pretrain-only model
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original, num_out=28)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.extract_features(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_allclose(features, reference)
Example - Fine-tuned model
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original.w2v_encoder)
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)
.. _fairseq: https://github.com/pytorch/fairseq
"""
class_ = original.__class__.__name__
if class_ == 'Wav2Vec2Model':
if num_out is None:
raise ValueError(
'When importing a pretrained model without readout layer, '
'`num_out` argument must be given.'
)
return _import_pretrained(original, num_out)
if class_ == 'Wav2VecEncoder':
return _import_finetuned(original)
raise ValueError(
f'Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}')
def _import_finetuned(original: Module) -> Wav2Vec2Model:
config = _parse_config(original.w2v_model, original.proj.out_features)
model = _get_model(**config)
model.load_state_dict(_convert_state_dict(original.state_dict()))
return model
def _import_pretrained(original: Module, num_out: int) -> Wav2Vec2Model:
config = _parse_config(original, num_out)
model = _get_model(**config)
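# strict=False: the pretrained (no-readout) checkpoint provides no weights for the
# newly created `encoder.readout` layer, so missing keys are expected here
# (presumably everything else is covered by _convert_state_dict).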
model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
return model