Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fa0c5a2e
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "88b86943b550957ffbc838dc63d455c5ead2a1d3"
Commit
fa0c5a2e
authored
Nov 13, 2018
by
lukovnikov
Browse files
clean up pr
parent
f4d79f44
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
39 deletions
+29
-39
convert_tf_checkpoint_to_pytorch.py
convert_tf_checkpoint_to_pytorch.py
+29
-39
No files found.
convert_tf_checkpoint_to_pytorch.py
View file @
fa0c5a2e
...
@@ -26,14 +26,35 @@ import numpy as np
...
@@ -26,14 +26,35 @@ import numpy as np
from
modeling
import
BertConfig
,
BertModel
from
modeling
import
BertConfig
,
BertModel
parser
=
argparse
.
ArgumentParser
()
def
convert
(
config_path
,
ckpt_path
,
out_path
=
None
):
## Required parameters
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
def
convert
():
# Initialise PyTorch model
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
config_
path
)
config
=
BertConfig
.
from_json_file
(
args
.
bert_
config_
file
)
model
=
BertModel
(
config
)
model
=
BertModel
(
config
)
# Load weights from TF model
# Load weights from TF model
path
=
ckp
t_path
path
=
args
.
tf_checkpoin
t_path
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
path
))
init_vars
=
tf
.
train
.
list_variables
(
path
)
init_vars
=
tf
.
train
.
list_variables
(
path
)
...
@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
...
@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
arrays
.
append
(
array
)
arrays
.
append
(
array
)
for
name
,
array
in
zip
(
names
,
arrays
):
for
name
,
array
in
zip
(
names
,
arrays
):
if
not
name
.
startswith
(
"bert"
):
name
=
name
[
5
:]
# skip "bert/"
print
(
"Skipping {}"
.
format
(
name
))
continue
else
:
name
=
name
.
replace
(
"bert/"
,
""
)
# skip "bert/"
print
(
"Loading {}"
.
format
(
name
))
print
(
"Loading {}"
.
format
(
name
))
name
=
name
.
split
(
'/'
)
name
=
name
.
split
(
'/'
)
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
if
name
[
0
]
in
[
'redictions'
,
'eq_relationship'
]:
# which are not required for using pretrained model
print
(
"Skipping"
)
if
name
[
0
]
in
[
'redictions'
,
'eq_relationship'
]
or
name
[
-
1
]
==
"adam_v"
or
name
[
-
1
]
==
"adam_m"
:
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
continue
pointer
=
model
pointer
=
model
for
m_name
in
name
:
for
m_name
in
name
:
...
@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
...
@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
pointer
.
data
=
torch
.
from_numpy
(
array
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
# Save pytorch-model
# Save pytorch-model
if
out_path
is
not
None
:
torch
.
save
(
model
.
state_dict
(),
args
.
pytorch_dump_path
)
torch
.
save
(
model
.
state_dict
(),
out_path
)
return
model
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
convert
()
## Required parameters
parser
.
add_argument
(
"--tf_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path the TensorFlow checkpoint path."
)
parser
.
add_argument
(
"--bert_config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
type
=
str
,
required
=
False
,
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
print
(
args
)
convert
(
args
.
bert_config_file
,
args
.
tf_checkpoint_path
,
args
.
pytorch_dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment