Unverified Commit 716aa416 authored by moto's avatar moto Committed by GitHub
Browse files

Add tool to convert voxpopuli model (#1923)

parent a7161298
#!/usr/bin/env python3
"""Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli
The available checkpoints should open with fairseq.
However, loading them fails with the following error on almost any version of fairseq.
https://github.com/facebookresearch/voxpopuli/issues/29
So this script manually parses the checkpoint file and reconstructs the model.
Examples
```
python convert_voxpopuli_models.py \
--input-file wav2vec2_base_10k_ft_fr.pt \
--output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""
def _parse_args():
import argparse
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
'--input-file', required=True,
help='Input checkpoint file.'
)
parser.add_argument(
'--output-file', required=False,
help='Output model file.'
)
return parser.parse_args()
def _load(input_file):
    """Load a voxpopuli fairseq checkpoint and extract its config and weights.

    Args:
        input_file (str): Path to the fairseq checkpoint (``.pt``).

    Returns:
        tuple: ``(cfg, state_dict)`` where ``cfg`` is the checkpoint config
        reduced to its ``'model'`` section (as a plain dict) and
        ``state_dict`` maps parameter names (with the ``'w2v_encoder.'``
        prefix stripped) to tensors.
    """
    import torch
    from omegaconf import OmegaConf

    # map_location='cpu' so checkpoints saved on GPU load on CPU-only hosts.
    data = torch.load(input_file, map_location='cpu')
    cfg = OmegaConf.to_container(data['cfg'])
    for key in list(cfg.keys()):
        if key != 'model':
            del cfg[key]
            # Fix: not every top-level key is mirrored under `w2v_args`;
            # the unconditional `del` raised KeyError for missing ones.
            cfg['model']['w2v_args'].pop(key, None)
    # str.removeprefix requires Python 3.9+.
    state_dict = {k.removeprefix('w2v_encoder.'): v for k, v in data['model'].items()}
    return cfg, state_dict
def _parse_model_param(cfg, state_dict):
key_mapping = {
"extractor_mode": "extractor_mode",
"conv_feature_layers": "extractor_conv_layer_config",
"conv_bias": "extractor_conv_bias",
"encoder_embed_dim": "encoder_embed_dim",
"dropout_input": "encoder_projection_dropout",
"conv_pos": "encoder_pos_conv_kernel",
"conv_pos_groups": "encoder_pos_conv_groups",
"encoder_layers": "encoder_num_layers",
"encoder_attention_heads": "encoder_num_heads",
"attention_dropout": "encoder_attention_dropout",
"encoder_ffn_embed_dim": "encoder_ff_interm_features",
"activation_dropout": "encoder_ff_interm_dropout",
"dropout": "encoder_dropout",
"layer_norm_first": "encoder_layer_norm_first",
"layerdrop": "encoder_layer_drop",
}
params = {}
for src, tgt in key_mapping.items():
for model_cfg in [cfg['model'], cfg['model']['w2v_args']['model']]:
if src in model_cfg:
params[tgt] = model_cfg[src]
break
if params["extractor_mode"] == "default":
params["extractor_mode"] = "group_norm"
params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
assert len(params) == len(key_mapping)
params['aux_num_out'] = state_dict['proj.bias'].numel()
return params
def _main(args):
    """Drive the conversion: load the checkpoint, translate its config,
    rebuild the model with torchaudio, and save the converted weights.

    Args:
        args (argparse.Namespace): Parsed CLI arguments with
            ``input_file`` and ``output_file`` attributes.
    """
    import json
    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    cfg, state_dict = _load(args.input_file)
    params = _parse_model_param(cfg, state_dict)
    # Echo the resolved parameters so the user can sanity-check the conversion.
    print(json.dumps(params, indent=4))
    converted = _convert(state_dict)
    model = torchaudio.models.wav2vec2_model(**params)
    model.load_state_dict(converted)
    torch.save(model.state_dict(), args.output_file)
# Script entry point: parse CLI arguments, then run the conversion.
if __name__ == '__main__':
    _main(_parse_args())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment