Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
716aa416
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bc4039886d4a1163b6f99912a17d4b82ad00adce"
Unverified
Commit
716aa416
authored
Oct 22, 2021
by
moto
Committed by
GitHub
Oct 22, 2021
Browse files
Add tool to convert voxpopuli model (#1923)
parent
a7161298
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
0 deletions
+99
-0
tools/convert_voxpopuli_models.py
tools/convert_voxpopuli_models.py
+99
-0
No files found.
tools/convert_voxpopuli_models.py
0 → 100755
View file @
716aa416
#!/usr/bin/env python3
"""Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli
The available checkpoints should open with fairseq.
But the following error cannot be resolved with almost any version of fairseq.
https://github.com/facebookresearch/voxpopuli/issues/29
So this script manually parses the checkpoint file and reconstructs the model.
Examples
```
python convert_voxpopuli_models.py
\
--input-file wav2vec2_base_10k_ft_fr.pt
\
--output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""
def
_parse_args
():
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
,
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
'--input-file'
,
required
=
True
,
help
=
'Input checkpoint file.'
)
parser
.
add_argument
(
'--output-file'
,
required
=
False
,
help
=
'Output model file.'
)
return
parser
.
parse_args
()
def _load(input_file):
    """Load a voxpopuli fairseq checkpoint without going through fairseq.

    Args:
        input_file (str): Path to the ``.pt`` checkpoint file.

    Returns:
        tuple: ``(cfg, state_dict)`` where ``cfg`` is the checkpoint config
        reduced to its ``'model'`` section and ``state_dict`` holds the model
        weights with the ``'w2v_encoder.'`` prefix stripped from each key.
    """
    import torch
    from omegaconf import OmegaConf

    checkpoint = torch.load(input_file)
    cfg = OmegaConf.to_container(checkpoint['cfg'])
    # Keep only the 'model' section; every other top-level section is dropped
    # from both the outer config and the nested pre-training config
    # ('w2v_args'), which mirrors the outer layout.
    for section in list(cfg.keys()):
        if section == 'model':
            continue
        del cfg[section]
        del cfg['model']['w2v_args'][section]
    # Strip the fine-tuning wrapper prefix so the keys match what
    # torchaudio's import helpers expect.  (str.removeprefix requires
    # Python 3.9+.)
    state_dict = {
        key.removeprefix('w2v_encoder.'): value
        for key, value in checkpoint['model'].items()
    }
    return cfg, state_dict
def
_parse_model_param
(
cfg
,
state_dict
):
key_mapping
=
{
"extractor_mode"
:
"extractor_mode"
,
"conv_feature_layers"
:
"extractor_conv_layer_config"
,
"conv_bias"
:
"extractor_conv_bias"
,
"encoder_embed_dim"
:
"encoder_embed_dim"
,
"dropout_input"
:
"encoder_projection_dropout"
,
"conv_pos"
:
"encoder_pos_conv_kernel"
,
"conv_pos_groups"
:
"encoder_pos_conv_groups"
,
"encoder_layers"
:
"encoder_num_layers"
,
"encoder_attention_heads"
:
"encoder_num_heads"
,
"attention_dropout"
:
"encoder_attention_dropout"
,
"encoder_ffn_embed_dim"
:
"encoder_ff_interm_features"
,
"activation_dropout"
:
"encoder_ff_interm_dropout"
,
"dropout"
:
"encoder_dropout"
,
"layer_norm_first"
:
"encoder_layer_norm_first"
,
"layerdrop"
:
"encoder_layer_drop"
,
}
params
=
{}
for
src
,
tgt
in
key_mapping
.
items
():
for
model_cfg
in
[
cfg
[
'model'
],
cfg
[
'model'
][
'w2v_args'
][
'model'
]]:
if
src
in
model_cfg
:
params
[
tgt
]
=
model_cfg
[
src
]
break
if
params
[
"extractor_mode"
]
==
"default"
:
params
[
"extractor_mode"
]
=
"group_norm"
params
[
"extractor_conv_layer_config"
]
=
eval
(
params
[
"extractor_conv_layer_config"
])
assert
len
(
params
)
==
len
(
key_mapping
)
params
[
'aux_num_out'
]
=
state_dict
[
'proj.bias'
].
numel
()
return
params
def _main(args):
    """Run the conversion: load, translate the config, rebuild, and save.

    Args:
        args (argparse.Namespace): Parsed CLI arguments with ``input_file``
            and ``output_file`` attributes.
    """
    import json
    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    config, fairseq_state = _load(args.input_file)
    model_params = _parse_model_param(config, fairseq_state)
    # Echo the resolved parameters so the conversion is auditable.
    print(json.dumps(model_params, indent=4))
    converted = torchaudio.models.wav2vec2_model(**model_params)
    converted.load_state_dict(_convert(fairseq_state))
    torch.save(converted.state_dict(), args.output_file)


if __name__ == '__main__':
    _main(_parse_args())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment