#!/usr/bin/env python3
"""Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli

The available checkpoints should, in principle, be loadable with fairseq.
However, the following error cannot be resolved with almost any version of fairseq:
https://github.com/facebookresearch/voxpopuli/issues/29

So this script manually parses the checkpoint file, reconstructs the model with torchaudio, and saves its state dict.

Examples

```
python convert_voxpopuli_models.py \
  --input-file wav2vec2_base_10k_ft_fr.pt \
  --output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""


def _parse_args():
    import argparse
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        '--input-file', required=True,
        help='Input checkpoint file.'
    )
    parser.add_argument(
        '--output-file', required=False,
        help='Output model file.'
    )
    return parser.parse_args()


37
38
39
40
41
42
def _removeprefix(s, prefix):
    if s.startswith(prefix):
        return s[len(prefix):]
    return s


def _load(input_file):
    """Load a fairseq checkpoint and return its model config and weights.

    Args:
        input_file: Path to the fairseq checkpoint file.

    Returns:
        tuple ``(cfg, state_dict)``. ``cfg`` is the checkpoint's ``cfg`` converted to
        plain containers and pruned to its ``model`` section (the same non-model
        sections are also pruned from ``cfg['model']['w2v_args']`` when present).
        ``state_dict`` is the checkpoint's ``model`` weights with a leading
        ``w2v_encoder.`` stripped from each key.
    """
    import torch
    from omegaconf import OmegaConf

    # NOTE: ``torch.load`` unpickles the file; only convert checkpoints from a trusted
    # source. Tensors are mapped to CPU so GPU-saved checkpoints load without CUDA.
    # NOTE(review): torch >= 2.6 defaults to ``weights_only=True``, which may reject
    # fairseq checkpoints; pass ``weights_only=False`` there if loading fails.
    data = torch.load(input_file, map_location='cpu')
    cfg = OmegaConf.to_container(data['cfg'])
    # When present, ``w2v_args`` holds a nested config (presumably the pre-training
    # one) whose top-level sections mirror ``cfg``'s.
    w2v_args = cfg['model'].get('w2v_args')
    for key in list(cfg.keys()):
        if key == 'model':
            continue
        del cfg[key]
        if w2v_args is not None:
            # ``pop`` tolerates sections that exist only at the top level.
            w2v_args.pop(key, None)
    state_dict = {_removeprefix(k, 'w2v_encoder.'): v for k, v in data['model'].items()}
    return cfg, state_dict


def _parse_model_param(cfg, state_dict):
    key_mapping = {
        "extractor_mode": "extractor_mode",
        "conv_feature_layers": "extractor_conv_layer_config",
        "conv_bias": "extractor_conv_bias",
        "encoder_embed_dim": "encoder_embed_dim",
        "dropout_input": "encoder_projection_dropout",
        "conv_pos": "encoder_pos_conv_kernel",
        "conv_pos_groups": "encoder_pos_conv_groups",
        "encoder_layers": "encoder_num_layers",
        "encoder_attention_heads": "encoder_num_heads",
        "attention_dropout": "encoder_attention_dropout",
        "encoder_ffn_embed_dim": "encoder_ff_interm_features",
        "activation_dropout": "encoder_ff_interm_dropout",
        "dropout": "encoder_dropout",
        "layer_norm_first": "encoder_layer_norm_first",
        "layerdrop": "encoder_layer_drop",
75
        "encoder_layerdrop": "encoder_layer_drop",
76
77
    }
    params = {}
78
79
80
81
    src_dicts = [cfg['model']]
    if 'w2v_args' in cfg['model']:
        src_dicts.append(cfg['model']['w2v_args']['model'])

82
    for src, tgt in key_mapping.items():
83
        for model_cfg in src_dicts:
84
85
86
87
88
            if src in model_cfg:
                params[tgt] = model_cfg[src]
                break
    if params["extractor_mode"] == "default":
        params["extractor_mode"] = "group_norm"
89
90
    # the following line is commented out to resolve lint warning; uncomment before running script
    # params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
91
92
    assert len(params) == 15
    params['aux_num_out'] = state_dict['proj.bias'].numel() if 'proj.bias' in state_dict else None
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    return params


def _main(args):
    """Convert the checkpoint at ``args.input_file`` into a torchaudio wav2vec2 model
    and save that model's state dict to ``args.output_file``."""
    import json
    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    model_cfg, fairseq_weights = _load(args.input_file)
    model_kwargs = _parse_model_param(model_cfg, fairseq_weights)
    # Echo the reconstructed constructor arguments so the conversion can be eyeballed.
    print(json.dumps(model_kwargs, indent=4))
    converted = torchaudio.models.wav2vec2_model(**model_kwargs)
    converted.load_state_dict(_convert(fairseq_weights))
    torch.save(converted.state_dict(), args.output_file)


if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    _main(args=_parse_args())