Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e703e4df
Unverified
Commit
e703e4df
authored
Oct 15, 2019
by
Thomas Wolf
Committed by
GitHub
Oct 15, 2019
Browse files
Merge pull request #1509 from julian-pani/patch-3
remove leftover usage of DUMMY_INPUTS
parents
d147671c
898ce064
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
2 deletions
+23
-2
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+1
-1
transformers/tests/modeling_tf_common_test.py
transformers/tests/modeling_tf_common_test.py
+22
-1
No files found.
transformers/modeling_tf_pytorch_utils.py
View file @
e703e4df
...
@@ -198,7 +198,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
...
@@ -198,7 +198,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model
=
tf_model_class
(
pt_model
.
config
)
tf_model
=
tf_model_class
(
pt_model
.
config
)
if
tf_inputs
is
None
:
if
tf_inputs
is
None
:
tf_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
tf_inputs
=
tf
_model
.
dummy_inputs
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
...
...
transformers/tests/modeling_tf_common_test.py
View file @
e703e4df
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
os
import
copy
import
copy
import
json
import
json
import
logging
import
logging
...
@@ -118,7 +119,7 @@ class TFCommonTestCases:
...
@@ -118,7 +119,7 @@ class TFCommonTestCases:
tf_model
=
model_class
(
config
)
tf_model
=
model_class
(
config
)
pt_model
=
pt_model_class
(
config
)
pt_model
=
pt_model_class
(
config
)
# Check we can load pt model in tf and vice-versa
(architecture similar)
# Check we can load pt model in tf and vice-versa
with model => model functions
tf_model
=
transformers
.
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
inputs_dict
)
tf_model
=
transformers
.
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
inputs_dict
)
pt_model
=
transformers
.
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
pt_model
=
transformers
.
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
...
@@ -132,6 +133,26 @@ class TFCommonTestCases:
...
@@ -132,6 +133,26 @@ class TFCommonTestCases:
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
self
.
assertLessEqual
(
max_diff
,
2e-2
)
self
.
assertLessEqual
(
max_diff
,
2e-2
)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with
TemporaryDirectory
()
as
tmpdirname
:
pt_checkpoint_path
=
os
.
path
.
join
(
tmpdirname
,
'pt_model.bin'
)
torch
.
save
(
pt_model
.
state_dict
(),
pt_checkpoint_path
)
tf_model
=
transformers
.
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pt_checkpoint_path
)
tf_checkpoint_path
=
os
.
path
.
join
(
tmpdirname
,
'tf_model.h5'
)
tf_model
.
save_weights
(
tf_checkpoint_path
)
pt_model
=
transformers
.
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model
.
eval
()
pt_inputs_dict
=
dict
((
name
,
torch
.
from_numpy
(
key
.
numpy
()).
to
(
torch
.
long
))
for
name
,
key
in
inputs_dict
.
items
())
with
torch
.
no_grad
():
pto
=
pt_model
(
**
pt_inputs_dict
)
tfo
=
tf_model
(
inputs_dict
)
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
self
.
assertLessEqual
(
max_diff
,
2e-2
)
def
test_compile_tf_model
(
self
):
def
test_compile_tf_model
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment