Unverified Commit 3ca81107 authored by moto, committed by GitHub
Browse files

Tweak wav2vec2 checkpoint conversion tool (#1938)

parent 18685a51
...@@ -34,6 +34,12 @@ def _parse_args(): ...@@ -34,6 +34,12 @@ def _parse_args():
return parser.parse_args() return parser.parse_args()
def _removeprefix(s, prefix):
    """Return ``s`` without a leading ``prefix``.

    Used in place of the built-in ``str.removeprefix`` method, presumably so
    the tool also runs on interpreters that lack it (it was added in
    Python 3.9).

    Args:
        s: The string to strip.
        prefix: The prefix to remove from the start of ``s``.

    Returns:
        ``s`` with one leading occurrence of ``prefix`` removed, or ``s``
        unchanged when it does not start with ``prefix``.
    """
    if s.startswith(prefix):
        # Drop exactly one occurrence; repeated prefixes are left in place.
        return s[len(prefix):]
    return s
def _load(input_file): def _load(input_file):
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
...@@ -43,9 +49,9 @@ def _load(input_file): ...@@ -43,9 +49,9 @@ def _load(input_file):
for key in list(cfg.keys()): for key in list(cfg.keys()):
if key != 'model': if key != 'model':
del cfg[key] del cfg[key]
del cfg['model']['w2v_args'][key] if 'w2v_args' in cfg['model']:
del cfg['model']['w2v_args'][key]
state_dict = {k.removeprefix('w2v_encoder.'): v for k, v in data['model'].items()} state_dict = {_removeprefix(k, 'w2v_encoder.'): v for k, v in data['model'].items()}
return cfg, state_dict return cfg, state_dict
...@@ -66,18 +72,23 @@ def _parse_model_param(cfg, state_dict): ...@@ -66,18 +72,23 @@ def _parse_model_param(cfg, state_dict):
"dropout": "encoder_dropout", "dropout": "encoder_dropout",
"layer_norm_first": "encoder_layer_norm_first", "layer_norm_first": "encoder_layer_norm_first",
"layerdrop": "encoder_layer_drop", "layerdrop": "encoder_layer_drop",
"encoder_layerdrop": "encoder_layer_drop",
} }
params = {} params = {}
src_dicts = [cfg['model']]
if 'w2v_args' in cfg['model']:
src_dicts.append(cfg['model']['w2v_args']['model'])
for src, tgt in key_mapping.items(): for src, tgt in key_mapping.items():
for model_cfg in [cfg['model'], cfg['model']['w2v_args']['model']]: for model_cfg in src_dicts:
if src in model_cfg: if src in model_cfg:
params[tgt] = model_cfg[src] params[tgt] = model_cfg[src]
break break
if params["extractor_mode"] == "default": if params["extractor_mode"] == "default":
params["extractor_mode"] = "group_norm" params["extractor_mode"] = "group_norm"
params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"]) params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
assert len(params) == len(key_mapping) assert len(params) == 15
params['aux_num_out'] = state_dict['proj.bias'].numel() params['aux_num_out'] = state_dict['proj.bias'].numel() if 'proj.bias' in state_dict else None
return params return params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment