Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
716aa416
Unverified
Commit
716aa416
authored
Oct 22, 2021
by
moto
Committed by
GitHub
Oct 22, 2021
Browse files
Add tool to convert voxpopuli model (#1923)
parent
a7161298
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
0 deletions
+99
-0
tools/convert_voxpopuli_models.py
tools/convert_voxpopuli_models.py
+99
-0
No files found.
tools/convert_voxpopuli_models.py
0 → 100755
View file @
716aa416
#!/usr/bin/env python3
"""Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli
The available checkpoints should open with fairseq.
But the following error cannot be resolved with almost any version of fairseq.
https://github.com/facebookresearch/voxpopuli/issues/29
So this script manually parses the checkpoint file and reconstructs the model.
Examples
```
python convert_voxpopuli_models.py
\
--input-file wav2vec2_base_10k_ft_fr.pt
\
--output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""
def
_parse_args
():
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
,
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
'--input-file'
,
required
=
True
,
help
=
'Input checkpoint file.'
)
parser
.
add_argument
(
'--output-file'
,
required
=
False
,
help
=
'Output model file.'
)
return
parser
.
parse_args
()
def _load(input_file):
    """Read a voxpopuli fairseq checkpoint and return ``(cfg, state_dict)``.

    The returned config keeps only the ``'model'`` section (mirrored inside
    the nested ``'w2v_args'`` config), and the state-dict keys are stripped
    of their ``'w2v_encoder.'`` prefix.
    """
    import torch
    from omegaconf import OmegaConf

    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from a trusted source.
    checkpoint = torch.load(input_file)
    config = OmegaConf.to_container(checkpoint['cfg'])

    # Drop every top-level section other than 'model', and remove the same
    # sections from the nested pretraining config under 'w2v_args'.
    extra_sections = [section for section in config if section != 'model']
    for section in extra_sections:
        del config[section]
        del config['model']['w2v_args'][section]

    # Rename weights so they match torchaudio's expected key layout.
    weights = {
        name.removeprefix('w2v_encoder.'): tensor
        for name, tensor in checkpoint['model'].items()
    }
    return config, weights
def
_parse_model_param
(
cfg
,
state_dict
):
key_mapping
=
{
"extractor_mode"
:
"extractor_mode"
,
"conv_feature_layers"
:
"extractor_conv_layer_config"
,
"conv_bias"
:
"extractor_conv_bias"
,
"encoder_embed_dim"
:
"encoder_embed_dim"
,
"dropout_input"
:
"encoder_projection_dropout"
,
"conv_pos"
:
"encoder_pos_conv_kernel"
,
"conv_pos_groups"
:
"encoder_pos_conv_groups"
,
"encoder_layers"
:
"encoder_num_layers"
,
"encoder_attention_heads"
:
"encoder_num_heads"
,
"attention_dropout"
:
"encoder_attention_dropout"
,
"encoder_ffn_embed_dim"
:
"encoder_ff_interm_features"
,
"activation_dropout"
:
"encoder_ff_interm_dropout"
,
"dropout"
:
"encoder_dropout"
,
"layer_norm_first"
:
"encoder_layer_norm_first"
,
"layerdrop"
:
"encoder_layer_drop"
,
}
params
=
{}
for
src
,
tgt
in
key_mapping
.
items
():
for
model_cfg
in
[
cfg
[
'model'
],
cfg
[
'model'
][
'w2v_args'
][
'model'
]]:
if
src
in
model_cfg
:
params
[
tgt
]
=
model_cfg
[
src
]
break
if
params
[
"extractor_mode"
]
==
"default"
:
params
[
"extractor_mode"
]
=
"group_norm"
params
[
"extractor_conv_layer_config"
]
=
eval
(
params
[
"extractor_conv_layer_config"
])
assert
len
(
params
)
==
len
(
key_mapping
)
params
[
'aux_num_out'
]
=
state_dict
[
'proj.bias'
].
numel
()
return
params
def _main(args):
    """Drive the conversion: load, translate config, rebuild, and save.

    Args:
        args (argparse.Namespace): parsed CLI options with ``input_file``
            and ``output_file`` paths.
    """
    import json
    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    # Parse the raw fairseq checkpoint into a config + weight dict.
    config, weights = _load(args.input_file)
    # Map fairseq config entries onto torchaudio constructor arguments.
    model_kwargs = _parse_model_param(config, weights)
    # Echo the resolved parameters so the user can sanity-check them.
    print(json.dumps(model_kwargs, indent=4))
    # Rebuild the model in torchaudio, import the converted weights,
    # and persist the resulting state dict.
    model = torchaudio.models.wav2vec2_model(**model_kwargs)
    model.load_state_dict(_convert(weights))
    torch.save(model.state_dict(), args.output_file)
# Script entry point: parse CLI arguments and run the conversion.
if __name__ == '__main__':
    _main(_parse_args())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment