#!/usr/bin/env python3 """Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli The available checkpoints should open with fairseq. But the following error cannot be resolved with almost any version of fairseq. https://github.com/facebookresearch/voxpopuli/issues/29 So this script manually parse the checkpoint file and reconstruct the model. Examples ``` python convert_voxpopuli_models.py \ --input-file wav2vec2_base_10k_ft_fr.pt \ --output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt ``` """ def _parse_args(): import argparse parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument( '--input-file', required=True, help='Input checkpoint file.' ) parser.add_argument( '--output-file', required=False, help='Output model file.' ) return parser.parse_args() def _removeprefix(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s def _load(input_file): import torch from omegaconf import OmegaConf data = torch.load(input_file) cfg = OmegaConf.to_container(data['cfg']) for key in list(cfg.keys()): if key != 'model': del cfg[key] if 'w2v_args' in cfg['model']: del cfg['model']['w2v_args'][key] state_dict = {_removeprefix(k, 'w2v_encoder.'): v for k, v in data['model'].items()} return cfg, state_dict def _parse_model_param(cfg, state_dict): key_mapping = { "extractor_mode": "extractor_mode", "conv_feature_layers": "extractor_conv_layer_config", "conv_bias": "extractor_conv_bias", "encoder_embed_dim": "encoder_embed_dim", "dropout_input": "encoder_projection_dropout", "conv_pos": "encoder_pos_conv_kernel", "conv_pos_groups": "encoder_pos_conv_groups", "encoder_layers": "encoder_num_layers", "encoder_attention_heads": "encoder_num_heads", "attention_dropout": "encoder_attention_dropout", "encoder_ffn_embed_dim": "encoder_ff_interm_features", "activation_dropout": "encoder_ff_interm_dropout", "dropout": "encoder_dropout", "layer_norm_first": "encoder_layer_norm_first", "layerdrop": "encoder_layer_drop", "encoder_layerdrop": "encoder_layer_drop", } params = {} src_dicts = [cfg['model']] if 'w2v_args' in cfg['model']: src_dicts.append(cfg['model']['w2v_args']['model']) for src, tgt in key_mapping.items(): for model_cfg in src_dicts: if src in model_cfg: params[tgt] = model_cfg[src] break if params["extractor_mode"] == "default": params["extractor_mode"] = "group_norm" # the following line is commented out to resolve lint warning; uncomment before running script # params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"]) assert len(params) == 15 params['aux_num_out'] = state_dict['proj.bias'].numel() if 'proj.bias' in state_dict else None return params def _main(args): import json import torch import torchaudio from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert cfg, state_dict = _load(args.input_file) params = _parse_model_param(cfg, state_dict) print(json.dumps(params, indent=4)) model = torchaudio.models.wav2vec2_model(**params) model.load_state_dict(_convert(state_dict)) torch.save(model.state_dict(), args.output_file) if __name__ == '__main__': _main(_parse_args())