Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2bcda8d0
Commit
2bcda8d0
authored
May 18, 2019
by
Chris
Browse files
update
parent
41089bc7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
8 deletions
+1
-8
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
+1
-8
No files found.
pytorch_pretrained_bert/convert_hf_checkpoint_to_tf.py
View file @
2bcda8d0
...
...
@@ -37,7 +37,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
N BertForQuestionAnswering
Note:
To keep
TF
out of package-level requirements,
tf i
s imported locally.
To keep
tf
out of package-level requirements,
it'
s imported locally.
"""
import
tensorflow
as
tf
...
...
@@ -52,9 +52,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
tf_vars
=
[]
def
to_tf_var_name
(
name
:
str
):
"""todo: compile as regex"""
name
=
name
.
replace
(
'layer.'
,
'layer_'
)
name
=
name
.
replace
(
'word_embeddings.weight'
,
'word_embeddings'
)
name
=
name
.
replace
(
'position_embeddings.weight'
,
'position_embeddings'
)
...
...
@@ -74,17 +72,12 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
return
tf_var
for
var_name
in
state_dict
:
tf_name
=
to_tf_var_name
(
var_name
)
torch_tensor
=
state_dict
[
var_name
].
numpy
()
if
var_name
.
endswith
(
'dense.weight'
):
torch_tensor
=
torch_tensor
.
T
tf_tensor
=
assign_tf_var
(
tensor
=
torch_tensor
,
name
=
tf_name
)
tf_vars
.
append
(
tf_tensor
)
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
saver
=
tf
.
train
.
Saver
(
tf_vars
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment