Commit b8da16f3 authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

Add (failing) tests for Keras save/load

parent ba281707
...@@ -19,8 +19,10 @@ import os ...@@ -19,8 +19,10 @@ import os
import random import random
import tempfile import tempfile
import unittest import unittest
from importlib import import_module
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
from transformers.modeling_tf_utils import TFMainLayer
from .utils import _tf_gpu_memory_limit, require_tf from .utils import _tf_gpu_memory_limit, require_tf
...@@ -88,14 +90,45 @@ class TFModelTesterMixin: ...@@ -88,14 +90,45 @@ class TFModelTesterMixin:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict) after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)
# Make sure we don't have nans def test_keras_save_load(self):
out_1 = after_outputs[0].numpy() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)] tf_main_layer_classes = set(
out_2 = out_2[~np.isnan(out_2)] module_member
max_diff = np.amax(np.abs(out_1 - out_2)) for model_class in self.all_model_classes
self.assertLessEqual(max_diff, 1e-5) for module in (import_module(model_class.__module__),)
for module_member_name in dir(module)
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and TFMainLayer in module_member.__bases__
)
for main_layer_class in tf_main_layer_classes:
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
}
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
outputs = model(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "keras_model.h5")
model.save(filepath)
model = tf.keras.models.load_model(
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
)
assert isinstance(model, tf.keras.Model)
after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)
def assert_outputs_same(self, after_outputs, outputs):
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_pt_tf_model_equivalence(self): def test_pt_tf_model_equivalence(self):
if not is_torch_available(): if not is_torch_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment