Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
edfd965a
Commit
edfd965a
authored
Jul 26, 2019
by
David Pollack
Browse files
fix convert_to_tf
parent
46cc9dd2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
+5
-5
No files found.
pytorch_transformers/convert_pytorch_checkpoint_to_tf.py
View file @
edfd965a
...
@@ -72,11 +72,11 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
...
@@ -72,11 +72,11 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:s
return
'bert/{}'
.
format
(
name
)
return
'bert/{}'
.
format
(
name
)
def
assign_tf_var
(
tensor
:
np
.
ndarray
,
name
:
str
):
def
assign_tf_var
(
tensor
:
np
.
ndarray
,
name
:
str
):
t
mp_var
=
tf
.
Variable
(
initial_value
=
tensor
)
t
f_dtype
=
tf
.
dtypes
.
as_dtype
(
tensor
.
dtype
)
tf_var
=
tf
.
get_variable
(
dtype
=
t
mp_var
.
dtype
,
shape
=
t
mp_va
r
.
shape
,
name
=
name
)
tf_var
=
tf
.
get_variable
(
dtype
=
t
f_
dtype
,
shape
=
t
enso
r
.
shape
,
name
=
name
)
op
=
tf
.
assign
(
ref
=
tf
_
var
,
value
=
tmp
_var
)
session
.
run
(
tf
.
var
iables_initializer
([
tf
_var
])
)
session
.
run
(
tf
.
variables_initializer
([
tmp_var
,
tf_var
])
)
tf
.
keras
.
backend
.
set_value
(
tf_var
,
tensor
)
session
.
run
(
fetches
=
[
op
,
tf_var
]
)
session
.
run
(
tf_var
)
return
tf_var
return
tf_var
for
var_name
in
state_dict
:
for
var_name
in
state_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment