Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
716cc1c4
Commit
716cc1c4
authored
Jun 19, 2019
by
chrislarson1
Browse files
added main() for programmatic call to convert pytorch->tf
parent
a8e071c6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
23 deletions
+32
-23
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
+32
-23
No files found.
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
View file @
716cc1c4
...
@@ -17,16 +17,18 @@
...
@@ -17,16 +17,18 @@
import
os
import
os
import
argparse
import
argparse
import
torch
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
pytorch_pretrained_bert.modeling
import
BertConfig
,
BertModel
from
pytorch_pretrained_bert.modeling
import
BertModel
def
convert_pytorch_checkpoint_to_tf
(
model
:
BertModel
,
ckpt_dir
:
str
):
def
convert_pytorch_checkpoint_to_tf
(
model
:
BertModel
,
ckpt_dir
:
str
,
model_name
:
str
):
"""
"""
:param model:BertModel Pytorch model instance to be converted
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model
:param ckpt_dir: Tensorflow model directory
:param model_name: model name
:return:
:return:
Currently supported HF models:
Currently supported HF models:
...
@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
...
@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
saver
=
tf
.
train
.
Saver
(
tf_vars
)
saver
=
tf
.
train
.
Saver
(
tf_vars
)
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
args
.
pytorch_model_name
))
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
model_name
.
replace
(
"-"
,
"_"
)
+
".ckpt"
))
if
__name__
==
"__main__"
:
def
main
(
raw_args
=
None
):
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--pytorch_model_dir"
,
parser
.
add_argument
(
"--model_name"
,
default
=
None
,
type
=
str
,
required
=
False
,
help
=
"Directory containing pytorch model"
)
parser
.
add_argument
(
"--pytorch_model_name"
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"model name (e.g. bert-base-uncased)"
)
help
=
"model name e.g. bert-base-uncased"
)
parser
.
add_argument
(
"--config_file_path"
,
parser
.
add_argument
(
"--cache_dir"
,
type
=
str
,
default
=
None
,
default
=
None
,
required
=
False
,
help
=
"Directory containing pytorch model"
)
parser
.
add_argument
(
"--pytorch_model_path"
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Path to bert config file"
)
help
=
"/path/to/<pytorch-model-name>.bin"
)
parser
.
add_argument
(
"--tf_checkpoint_dir"
,
parser
.
add_argument
(
"--tf_cache_dir"
,
default
=
""
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Directory in which to save tensorflow model"
)
help
=
"Directory in which to save tensorflow model"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
(
raw_args
)
model
=
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
=
args
.
model_name
,
state_dict
=
torch
.
load
(
args
.
pytorch_model_path
),
cache_dir
=
args
.
cache_dir
)
model
=
BertModel
(
convert_pytorch_checkpoint_to_tf
(
config
=
BertConfig
(
args
.
config_file_path
)
model
=
model
,
).
from_pretrained
(
args
.
pytorch_model_name
,
cache_dir
=
args
.
pytorch_model_dir
)
ckpt_dir
=
args
.
tf_cache_dir
,
convert_pytorch_checkpoint_to_tf
(
model
=
model
,
ckpt_dir
=
args
.
tf_checkpoint_dir
)
model_name
=
args
.
model_name
)
if
__name__
==
"__main__"
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment