Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
09ecf225
Commit
09ecf225
authored
Jul 26, 2019
by
David Pollack
Browse files
fixed the fix. tf session madness.
parent
edfd965a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
16 deletions
+16
-16
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
+16
-16
No files found.
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
View file @
09ecf225
...
...
@@ -62,33 +62,33 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
session
=
tf
.
Session
()
state_dict
=
model
.
state_dict
()
tf_vars
=
[]
def
to_tf_var_name
(
name
:
str
):
for
patt
,
repl
in
iter
(
var_map
):
name
=
name
.
replace
(
patt
,
repl
)
return
'bert/{}'
.
format
(
name
)
def
assign
_tf_var
(
tensor
:
np
.
ndarray
,
name
:
str
):
def
create
_tf_var
(
tensor
:
np
.
ndarray
,
name
:
str
,
session
:
tf
.
Session
):
tf_dtype
=
tf
.
dtypes
.
as_dtype
(
tensor
.
dtype
)
tf_var
=
tf
.
get_variable
(
dtype
=
tf_dtype
,
shape
=
tensor
.
shape
,
name
=
name
)
tf_var
=
tf
.
get_variable
(
dtype
=
tf_dtype
,
shape
=
tensor
.
shape
,
name
=
name
,
initializer
=
tf
.
zeros_initializer
()
)
session
.
run
(
tf
.
variables_initializer
([
tf_var
]))
tf
.
keras
.
backend
.
set_value
(
tf_var
,
tensor
)
session
.
run
(
tf_var
)
return
tf_var
tf
.
reset_default_graph
()
with
tf
.
Session
()
as
session
:
for
var_name
in
state_dict
:
tf_name
=
to_tf_var_name
(
var_name
)
torch_tensor
=
state_dict
[
var_name
].
numpy
()
if
any
([
x
in
var_name
for
x
in
tensors_to_transopse
]):
torch_tensor
=
torch_tensor
.
T
tf_tensor
=
assign_tf_var
(
tensor
=
torch_tensor
,
name
=
tf_name
)
tf_vars
.
append
(
tf_tensor
)
print
(
"{0}{1}initialized"
.
format
(
tf_name
,
" "
*
(
60
-
len
(
tf_name
))))
tf_var
=
create_tf_var
(
tensor
=
torch_tensor
,
name
=
tf_name
,
session
=
session
)
tf
.
keras
.
backend
.
set_value
(
tf_var
,
torch_tensor
)
tf_weight
=
session
.
run
(
tf_var
)
print
(
"Successfully created {}: {}"
.
format
(
tf_name
,
np
.
allclose
(
tf_weight
,
torch_tensor
)))
saver
=
tf
.
train
.
Saver
(
tf
_vars
)
saver
=
tf
.
train
.
Saver
(
tf
.
trainable_variables
()
)
saver
.
save
(
session
,
os
.
path
.
join
(
ckpt_dir
,
model_name
.
replace
(
"-"
,
"_"
)
+
".ckpt"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment