Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f1433db4
Commit
f1433db4
authored
May 18, 2019
by
Chris
Browse files
update to hf->tf args
parent
077a5b0d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
33 deletions
+20
-33
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
+20
-33
No files found.
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
View file @
f1433db4
...
...
@@ -18,16 +18,18 @@
import
os
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
from
pytorch_pretrained_bert.modeling
import
BertConfig
,
BertModel
def
convert_hf_checkpoint_to_tf
(
model
:
BertModel
,
ckpt_dir
:
str
):
def
convert_hf_checkpoint_to_tf
(
model
:
type
(
BertModel
)
,
ckpt_dir
:
str
):
"""
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model
:return:
S
upported HF models:
Currently s
upported HF models:
Y BertModel
N BertForMaskedLM
N BertForPreTraining
...
...
@@ -35,20 +37,13 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
N BertForNextSentencePrediction
N BertForSequenceClassification
N BertForQuestionAnswering
Note:
To keep tf out of package-level requirements, it's imported locally.
"""
import
tensorflow
as
tf
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
session
=
tf
.
Session
()
state_dict
=
model
.
state_dict
()
tf_vars
=
[]
def
to_tf_var_name
(
name
:
str
):
...
...
@@ -61,6 +56,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
name
=
name
.
replace
(
'LayerNorm/weight'
,
'LayerNorm/gamma'
)
name
=
name
.
replace
(
'LayerNorm/bias'
,
'LayerNorm/beta'
)
name
=
name
.
replace
(
'weight'
,
'kernel'
)
# name += ':0'
return
'bert/{}'
.
format
(
name
)
def
assign_tf_var
(
tensor
:
np
.
ndarray
,
name
:
str
):
...
...
@@ -81,44 +77,35 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
saver
=
tf
.
train
.
Saver
(
tf_vars
)
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
args
.
pytorch_model_name
))
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--p
retrained_model_name_or_path
"
,
parser
.
add_argument
(
"--p
ytorch_model_dir
"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"pretrained_model_name_or_path: either:
\
- a str with the name of a pre-trained model to load selected in the list of:
\
. `bert-base-uncased`
\
. `bert-large-uncased`
\
. `bert-base-cased`
\
. `bert-large-cased`
\
. `bert-base-multilingual-uncased`
\
. `bert-base-multilingual-cased`
\
. `bert-base-chinese`
\
- a path or url to a pretrained model archive containing:
\
. `bert_config.json` a configuration file for the model
\
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
\
- a path or url to a pretrained model archive containing:
\
. `bert_config.json` a configuration file for the model
\
. `model.ckpt` a TensorFlow checkpoint"
)
parser
.
add_argument
(
"--config_file_path"
,
help
=
"Directory containing pytorch model"
)
parser
.
add_argument
(
"--pytorch_model_name"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"
Path to bert config file.
"
)
parser
.
add_argument
(
"--c
ache_dir
"
,
help
=
"
model name (e.g. bert-base-uncased)
"
)
parser
.
add_argument
(
"--c
onfig_file_path
"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to a folder in which the TF model will be cached."
)
help
=
"Path to bert config file"
)
parser
.
add_argument
(
"--tf_checkpoint_dir"
,
default
=
""
,
type
=
str
,
required
=
True
,
help
=
"Directory in which to save tensorflow model"
)
args
=
parser
.
parse_args
()
model
=
BertModel
(
config
=
BertConfig
(
args
.
config_file_path
)
).
from_pretrained
(
args
.
pretrained_model_name_or_path
)
convert_hf_checkpoint_to_tf
(
model
=
model
,
ckpt_dir
=
args
.
cache_dir
)
\ No newline at end of file
).
from_pretrained
(
args
.
pytorch_model_name
,
cache_dir
=
args
.
pytorch_model_dir
)
convert_hf_checkpoint_to_tf
(
model
=
model
,
ckpt_dir
=
args
.
tf_checkpoint_dir
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment