"...git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "0922210c5cf2ca0f93fa5f924c6ed195ec7da7e2"
Commit bb04edb4 authored by thomwolf's avatar thomwolf
Browse files

Add tests that TF 2.0 model can be integrated with other Keras modules

parent 6596e3d5
...@@ -22,6 +22,7 @@ import random ...@@ -22,6 +22,7 @@ import random
import shutil import shutil
import unittest import unittest
import uuid import uuid
import tempfile
import pytest import pytest
import sys import sys
...@@ -36,6 +37,20 @@ if is_tf_available(): ...@@ -36,6 +37,20 @@ if is_tf_available():
else: else:
pytestmark = pytest.mark.skip("Require TensorFlow") pytestmark = pytest.mark.skip("Require TensorFlow")
if sys.version_info[0] == 2:
import cPickle as pickle
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
import pickle
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def _config_zero_init(config): def _config_zero_init(config):
configs_no_init = copy.deepcopy(config) configs_no_init = copy.deepcopy(config)
...@@ -66,13 +81,25 @@ class TFCommonTestCases: ...@@ -66,13 +81,25 @@ class TFCommonTestCases:
# self.assertIn(param.data.mean().item(), [0.0, 1.0], # self.assertIn(param.data.mean().item(), [0.0, 1.0],
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
outputs = model(inputs_dict)
with TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict)
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
self.assertLessEqual(max_diff, 1e-5)
def test_pt_tf_model_equivalence(self): def test_pt_tf_model_equivalence(self):
if not is_torch_available(): if not is_torch_available():
return return
import torch import torch
import numpy as np
import transformers import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -99,6 +126,34 @@ class TFCommonTestCases: ...@@ -99,6 +126,34 @@ class TFCommonTestCases:
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy())) max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2) self.assertLessEqual(max_diff, 2e-2)
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
for model_class in self.all_model_classes:
# Prepare our model
model = model_class(config)
# Let's load it from the disk to be sure we can use pretrained weights
with TemporaryDirectory() as tmpdirname:
outputs = model(inputs_dict) # build the model
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test intetgration with other keras modules
outputs = tf.keras.layers.Dense(2, activation='softmax', name='outputs')(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_keyword_and_dict_args(self): def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment