Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b8da16f3
Commit
b8da16f3
authored
Mar 03, 2020
by
Gunnlaugur Thor Briem
Browse files
Add (failing) tests for Keras save/load
parent
ba281707
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
7 deletions
+40
-7
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+40
-7
No files found.
tests/test_modeling_tf_common.py
View file @
b8da16f3
...
...
@@ -19,8 +19,10 @@ import os
import
random
import
tempfile
import
unittest
from
importlib
import
import_module
from
transformers
import
is_tf_available
,
is_torch_available
from
transformers.modeling_tf_utils
import
TFMainLayer
from
.utils
import
_tf_gpu_memory_limit
,
require_tf
...
...
@@ -88,7 +90,38 @@ class TFModelTesterMixin:
model
.
save_pretrained
(
tmpdirname
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
after_outputs
=
model
(
inputs_dict
)
self
.
assert_outputs_same
(
after_outputs
,
outputs
)
def
test_keras_save_load
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
tf_main_layer_classes
=
set
(
module_member
for
model_class
in
self
.
all_model_classes
for
module
in
(
import_module
(
model_class
.
__module__
),)
for
module_member_name
in
dir
(
module
)
for
module_member
in
(
getattr
(
module
,
module_member_name
),)
if
isinstance
(
module_member
,
type
)
and
TFMainLayer
in
module_member
.
__bases__
)
for
main_layer_class
in
tf_main_layer_classes
:
main_layer
=
main_layer_class
(
config
)
symbolic_inputs
=
{
name
:
tf
.
keras
.
Input
(
tensor
.
shape
[
1
:],
dtype
=
tensor
.
dtype
)
for
name
,
tensor
in
inputs_dict
.
items
()
}
model
=
tf
.
keras
.
Model
(
symbolic_inputs
,
outputs
=
main_layer
(
symbolic_inputs
))
outputs
=
model
(
inputs_dict
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
filepath
=
os
.
path
.
join
(
tmpdirname
,
"keras_model.h5"
)
model
.
save
(
filepath
)
model
=
tf
.
keras
.
models
.
load_model
(
filepath
,
custom_objects
=
{
main_layer_class
.
__name__
:
main_layer_class
}
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
)
after_outputs
=
model
(
inputs_dict
)
self
.
assert_outputs_same
(
after_outputs
,
outputs
)
def
assert_outputs_same
(
self
,
after_outputs
,
outputs
):
# Make sure we don't have nans
out_1
=
after_outputs
[
0
].
numpy
()
out_2
=
outputs
[
0
].
numpy
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment