Unverified Commit 3ca81107 authored by moto's avatar moto Committed by GitHub
Browse files

Tweak wav2vec2 checkpoint conversion tool (#1938)

parent 18685a51
......@@ -34,6 +34,12 @@ def _parse_args():
return parser.parse_args()
def _removeprefix(s, prefix):
if s.startswith(prefix):
return s[len(prefix):]
return s
def _load(input_file):
import torch
from omegaconf import OmegaConf
......@@ -43,9 +49,9 @@ def _load(input_file):
for key in list(cfg.keys()):
if key != 'model':
del cfg[key]
if 'w2v_args' in cfg['model']:
del cfg['model']['w2v_args'][key]
state_dict = {k.removeprefix('w2v_encoder.'): v for k, v in data['model'].items()}
state_dict = {_removeprefix(k, 'w2v_encoder.'): v for k, v in data['model'].items()}
return cfg, state_dict
......@@ -66,18 +72,23 @@ def _parse_model_param(cfg, state_dict):
"dropout": "encoder_dropout",
"layer_norm_first": "encoder_layer_norm_first",
"layerdrop": "encoder_layer_drop",
"encoder_layerdrop": "encoder_layer_drop",
}
params = {}
src_dicts = [cfg['model']]
if 'w2v_args' in cfg['model']:
src_dicts.append(cfg['model']['w2v_args']['model'])
for src, tgt in key_mapping.items():
for model_cfg in [cfg['model'], cfg['model']['w2v_args']['model']]:
for model_cfg in src_dicts:
if src in model_cfg:
params[tgt] = model_cfg[src]
break
if params["extractor_mode"] == "default":
params["extractor_mode"] = "group_norm"
params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
assert len(params) == len(key_mapping)
params['aux_num_out'] = state_dict['proj.bias'].numel()
assert len(params) == 15
params['aux_num_out'] = state_dict['proj.bias'].numel() if 'proj.bias' in state_dict else None
return params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment