Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
716cc1c4
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dd288303273a96ebe38346d048b2900bbd747989"
Commit
716cc1c4
authored
Jun 19, 2019
by
chrislarson1
Browse files
added main() for programmatic call to convert pytorch->tf
parent
a8e071c6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
23 deletions
+32
-23
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
+32
-23
No files found.
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
View file @
716cc1c4
...
@@ -17,16 +17,18 @@
...
@@ -17,16 +17,18 @@
import
os
import
os
import
argparse
import
argparse
import
torch
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
pytorch_pretrained_bert.modeling
import
BertConfig
,
BertModel
from
pytorch_pretrained_bert.modeling
import
BertModel
def
convert_pytorch_checkpoint_to_tf
(
model
:
BertModel
,
ckpt_dir
:
str
):
def
convert_pytorch_checkpoint_to_tf
(
model
:
BertModel
,
ckpt_dir
:
str
,
model_name
:
str
):
"""
"""
:param model:BertModel Pytorch model instance to be converted
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model
:param ckpt_dir: Tensorflow model directory
:param model_name: model name
:return:
:return:
Currently supported HF models:
Currently supported HF models:
...
@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
...
@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
saver
=
tf
.
train
.
Saver
(
tf_vars
)
saver
=
tf
.
train
.
Saver
(
tf_vars
)
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
args
.
pytorch_model_name
))
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
model_name
.
replace
(
"-"
,
"_"
)
+
".ckpt"
))
if
__name__
==
"__main__"
:
def
main
(
raw_args
=
None
):
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--pytorch_model_dir"
,
parser
.
add_argument
(
"--model_name"
,
default
=
None
,
type
=
str
,
required
=
False
,
help
=
"Directory containing pytorch model"
)
parser
.
add_argument
(
"--pytorch_model_name"
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"model name (e.g. bert-base-uncased)"
)
help
=
"model name e.g. bert-base-uncased"
)
parser
.
add_argument
(
"--config_file_path"
,
parser
.
add_argument
(
"--cache_dir"
,
type
=
str
,
default
=
None
,
default
=
None
,
required
=
False
,
help
=
"Directory containing pytorch model"
)
parser
.
add_argument
(
"--pytorch_model_path"
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Path to bert config file"
)
help
=
"/path/to/<pytorch-model-name>.bin"
)
parser
.
add_argument
(
"--tf_checkpoint_dir"
,
parser
.
add_argument
(
"--tf_cache_dir"
,
default
=
""
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Directory in which to save tensorflow model"
)
help
=
"Directory in which to save tensorflow model"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
(
raw_args
)
model
=
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
=
args
.
model_name
,
state_dict
=
torch
.
load
(
args
.
pytorch_model_path
),
cache_dir
=
args
.
cache_dir
)
convert_pytorch_checkpoint_to_tf
(
model
=
model
,
ckpt_dir
=
args
.
tf_cache_dir
,
model_name
=
args
.
model_name
)
model
=
BertModel
(
if
__name__
==
"__main__"
:
config
=
BertConfig
(
args
.
config_file_path
)
main
()
).
from_pretrained
(
args
.
pytorch_model_name
,
cache_dir
=
args
.
pytorch_model_dir
)
convert_pytorch_checkpoint_to_tf
(
model
=
model
,
ckpt_dir
=
args
.
tf_checkpoint_dir
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment