Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
3ca81107
"vscode:/vscode.git/clone" did not exist on "a2331a99f2a0a715f63acaa4943684297fbbeb47"
Unverified
Commit
3ca81107
authored
Oct 27, 2021
by
moto
Committed by
GitHub
Oct 27, 2021
Browse files
Tweak wav2vec2 checkpoint conversion tool (#1938)
parent
18685a51
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
tools/convert_voxpopuli_models.py
tools/convert_voxpopuli_models.py
+17
-6
No files found.
tools/convert_voxpopuli_models.py
View file @
3ca81107
...
@@ -34,6 +34,12 @@ def _parse_args():
...
@@ -34,6 +34,12 @@ def _parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
_removeprefix
(
s
,
prefix
):
if
s
.
startswith
(
prefix
):
return
s
[
len
(
prefix
):]
return
s
def
_load
(
input_file
):
def
_load
(
input_file
):
import
torch
import
torch
from
omegaconf
import
OmegaConf
from
omegaconf
import
OmegaConf
...
@@ -43,9 +49,9 @@ def _load(input_file):
...
@@ -43,9 +49,9 @@ def _load(input_file):
for
key
in
list
(
cfg
.
keys
()):
for
key
in
list
(
cfg
.
keys
()):
if
key
!=
'model'
:
if
key
!=
'model'
:
del
cfg
[
key
]
del
cfg
[
key
]
if
'w2v_args'
in
cfg
[
'model'
]:
del
cfg
[
'model'
][
'w2v_args'
][
key
]
del
cfg
[
'model'
][
'w2v_args'
][
key
]
state_dict
=
{
_removeprefix
(
k
,
'w2v_encoder.'
):
v
for
k
,
v
in
data
[
'model'
].
items
()}
state_dict
=
{
k
.
removeprefix
(
'w2v_encoder.'
):
v
for
k
,
v
in
data
[
'model'
].
items
()}
return
cfg
,
state_dict
return
cfg
,
state_dict
...
@@ -66,18 +72,23 @@ def _parse_model_param(cfg, state_dict):
...
@@ -66,18 +72,23 @@ def _parse_model_param(cfg, state_dict):
"dropout"
:
"encoder_dropout"
,
"dropout"
:
"encoder_dropout"
,
"layer_norm_first"
:
"encoder_layer_norm_first"
,
"layer_norm_first"
:
"encoder_layer_norm_first"
,
"layerdrop"
:
"encoder_layer_drop"
,
"layerdrop"
:
"encoder_layer_drop"
,
"encoder_layerdrop"
:
"encoder_layer_drop"
,
}
}
params
=
{}
params
=
{}
src_dicts
=
[
cfg
[
'model'
]]
if
'w2v_args'
in
cfg
[
'model'
]:
src_dicts
.
append
(
cfg
[
'model'
][
'w2v_args'
][
'model'
])
for
src
,
tgt
in
key_mapping
.
items
():
for
src
,
tgt
in
key_mapping
.
items
():
for
model_cfg
in
[
cfg
[
'model'
],
cfg
[
'model'
][
'w2v_args'
][
'model'
]]
:
for
model_cfg
in
src_dicts
:
if
src
in
model_cfg
:
if
src
in
model_cfg
:
params
[
tgt
]
=
model_cfg
[
src
]
params
[
tgt
]
=
model_cfg
[
src
]
break
break
if
params
[
"extractor_mode"
]
==
"default"
:
if
params
[
"extractor_mode"
]
==
"default"
:
params
[
"extractor_mode"
]
=
"group_norm"
params
[
"extractor_mode"
]
=
"group_norm"
params
[
"extractor_conv_layer_config"
]
=
eval
(
params
[
"extractor_conv_layer_config"
])
params
[
"extractor_conv_layer_config"
]
=
eval
(
params
[
"extractor_conv_layer_config"
])
assert
len
(
params
)
==
len
(
key_mapping
)
assert
len
(
params
)
==
15
params
[
'aux_num_out'
]
=
state_dict
[
'proj.bias'
].
numel
()
params
[
'aux_num_out'
]
=
state_dict
[
'proj.bias'
].
numel
()
if
'proj.bias'
in
state_dict
else
None
return
params
return
params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment