Unverified Commit 3b43b018 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1482 from huggingface/tf2_integration_tests

Integration of TF 2.0 models with other Keras modules
parents 700331b5 4b8f3e8f
......@@ -22,6 +22,7 @@ import random
import shutil
import unittest
import uuid
import tempfile
import pytest
import sys
......@@ -36,6 +37,20 @@ if is_tf_available():
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
# Python 2 / Python 3 compatibility shims used by the tests below.
if sys.version_info[0] == 2:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.

        Entering the context creates a directory with ``tempfile.mkdtemp()``
        and yields its path; leaving the context removes the directory tree
        with ``shutil.rmtree()``.
        """

        def __enter__(self):
            # Keep the path on the instance so __exit__ knows what to delete.
            self.name = tempfile.mkdtemp()
            return self.name

        def __exit__(self, exc_type, exc_value, traceback):
            # Returns None, so an exception raised inside the "with" body propagates.
            shutil.rmtree(self.name)
else:
    import pickle

    # On this branch the standard library already ships the context manager.
    TemporaryDirectory = tempfile.TemporaryDirectory
    # Alias so code written against the Python 2 name `unicode` keeps working.
    unicode = str
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
......@@ -66,13 +81,31 @@ class TFCommonTestCases:
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
outputs = model(inputs_dict)
with TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict)
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_pt_tf_model_equivalence(self):
if not is_torch_available():
return
import torch
import numpy as np
import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -99,6 +132,34 @@ class TFCommonTestCases:
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_compile_tf_model(self):
    """Check that each model can be wrapped in a Keras functional model and compiled.

    Every class in ``self.all_model_classes`` is built, saved to a temporary
    directory and reloaded with ``from_pretrained``. The reloaded model is
    then called on a symbolic ``tf.keras.Input``, a ``Dense`` layer is stacked
    on its first output, and the resulting ``tf.keras.Model`` is compiled
    with an optimizer, a loss and a metric.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

    for model_class in self.all_model_classes:
        # Prepare our model
        model = model_class(config)

        # Let's load it from the disk to be sure we can use pretrained weights
        with TemporaryDirectory() as tmpdirname:
            model(inputs_dict)  # build the model; the result itself is not needed
            model.save_pretrained(tmpdirname)
            model = model_class.from_pretrained(tmpdirname)

        outputs_dict = model(input_ids)
        hidden_states = outputs_dict[0]

        # Add a dense layer on top to test integration with other keras modules.
        # No softmax activation here: the loss is built with from_logits=True
        # and therefore expects raw logits from this layer.
        outputs = tf.keras.layers.Dense(2, name='outputs')(hidden_states)

        # Compile extended model
        extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
        extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -161,6 +161,11 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"outputs": outputs.numpy(),
}
config.mem_len = 0
model = TFXLNetModel(config)
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].shape),
[self.batch_size, self.seq_length, self.hidden_size])
......
......@@ -150,6 +150,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs,
}
config.mem_len = 0
model = XLNetModel(config)
model.eval()
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment