Unverified Commit f7cbc13d authored by Lysandre Debut, committed by GitHub

Test model outputs equivalence (#6445)

* Test model outputs equivalence

* Fix failing tests

* From dict to kwargs

* DistilBERT

* Addressing @sgugger and @patrickvonplaten's comments
parent 54c687e9
@@ -21,13 +21,17 @@ import tensorflow as tf
 from .configuration_longformer import LongformerConfig
 from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
-from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput
+from .modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPooling,
+    TFMaskedLMOutput,
+    TFQuestionAnsweringModelOutput,
+)
 from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
 from .modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
-    cast_bool_to_primitive,
     get_initializer,
     keras_serializable,
     shape_list,
@@ -833,33 +837,41 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
             TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
         ]

-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        padding_len=0,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+    ):

-        all_hidden_states = ()
-        all_attentions = ()
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
         for i, layer_module in enumerate(self.layer):
-            if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
+            if output_hidden_states:
                 hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                 all_hidden_states = all_hidden_states + (hidden_states_to_add,)

             layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
             hidden_states = layer_outputs[0]

-            if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
+            if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)

         # Add last layer
-        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
+        if output_hidden_states:
             hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
             all_hidden_states = all_hidden_states + (hidden_states_to_add,)

-        outputs = (hidden_states,)
-        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
-            outputs = outputs + (all_hidden_states,)
-        if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
-            outputs = outputs + (all_attentions,)
-        return outputs  # outputs, (hidden states), (attentions)
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )


 @keras_serializable
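
The hunk above is an instance of the tuple-vs-ModelOutput pattern this PR standardizes: with return_dict=False the layer returns a positional tuple with None entries filtered out, and with return_dict=True it returns a structured output whose to_tuple() reproduces exactly that tuple. A minimal, self-contained sketch of the idea, using only the standard library (ExampleOutput and encode are illustrative stand-ins, not the real transformers classes):

# Minimal sketch of the return_dict pattern; ExampleOutput stands in for
# TFBaseModelOutput and is not the real class.
from dataclasses import dataclass, fields
from typing import Any, Optional, Tuple


@dataclass
class ExampleOutput:
    last_hidden_state: Any = None
    hidden_states: Optional[Tuple] = None
    attentions: Optional[Tuple] = None

    def to_tuple(self):
        # Skip None fields so the tuple matches the legacy return format.
        return tuple(getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None)


def encode(hidden_states, output_hidden_states=False, return_dict=True):
    all_hidden_states = (hidden_states,) if output_hidden_states else None
    if not return_dict:
        # Legacy behavior: a positional tuple with None entries dropped.
        return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
    return ExampleOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)


assert encode("h", return_dict=True).to_tuple() == encode("h", return_dict=False)

Because to_tuple() filters None exactly like the legacy tuple construction, both code paths expose the same positional values, which is what the new equivalence test below checks.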
@@ -992,7 +1004,12 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

         encoder_outputs = self.encoder(
-            [embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len],
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            padding_len=padding_len,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
             training=training,
         )
......
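
The call-site change above replaces a packed positional list with self-documenting keyword arguments. A hedged end-to-end usage sketch (it assumes a transformers version from around this commit, where TF models accept return_dict, and uses the public allenai/longformer-base-4096 checkpoint purely for illustration; none of this snippet is in the diff):

# Hedged usage sketch; checkpoint and API era are assumptions, not tested here.
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello world", return_tensors="tf")
outputs = model(inputs, return_dict=True)   # structured output object
legacy = model(inputs, return_dict=False)   # plain tuple, old behavior

# Both paths should carry the same tensors.
assert bool(tf.reduce_all(tf.equal(outputs.last_hidden_state, legacy[0])))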
@@ -18,7 +18,7 @@ import os.path
 import random
 import tempfile
 import unittest
-from typing import List
+from typing import List, Tuple

 from transformers import is_torch_available
 from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
@@ -37,6 +37,11 @@ if is_torch_available():
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         top_k_top_p_filtering,
     )
@@ -63,14 +68,39 @@ class ModelTesterMixin:
     test_chunking = False
     is_encoder_decoder = False

-    def _prepare_for_class(self, inputs_dict, model_class):
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
         if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-            return {
+            inputs_dict = {
                 k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                 if isinstance(v, torch.Tensor) and v.ndim > 1
                 else v
                 for k, v in inputs_dict.items()
             }
+
+        if return_labels:
+            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
+            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+                inputs_dict["start_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+                inputs_dict["end_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class in [
+                *MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
+                *MODEL_FOR_CAUSAL_LM_MAPPING.values(),
+                *MODEL_FOR_MASKED_LM_MAPPING.values(),
+                *MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+            ]:
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
         return inputs_dict

     def test_save_load(self):
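
To make the label conventions above concrete, here is a small standalone illustration of the shapes being constructed (the batch_size and seq_length values mirror typical model testers but are otherwise arbitrary; zeros are used because class index 0 is always valid):

# Standalone shape illustration; not part of the diff.
import torch

batch_size, seq_length = 13, 7

# One label per sequence (multiple choice, sequence classification).
seq_labels = torch.zeros(batch_size, dtype=torch.long)

# One span index per sequence (question answering start/end positions).
start_positions = torch.zeros(batch_size, dtype=torch.long)

# One label per token (token classification, causal/masked/seq2seq LM).
token_labels = torch.zeros((batch_size, seq_length), dtype=torch.long)

print(seq_labels.shape, start_positions.shape, token_labels.shape)
# torch.Size([13]) torch.Size([13]) torch.Size([13, 7])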
@@ -663,6 +693,64 @@ class ModelTesterMixin:
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            with torch.no_grad():
+                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+            def recursive_check(tuple_object, dict_object):
+                if isinstance(tuple_object, (List, Tuple)):
+                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif tuple_object is None:
+                    return
+                else:
+                    self.assertTrue(
+                        torch.allclose(tuple_object, dict_object, atol=1e-5),
+                        msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}",
+                    )
+
+            recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(
+                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+            )
+
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
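
The invariant this test pins down, reduced to a few lines (DummyOutput is a hypothetical stand-in for transformers' ModelOutput subclasses, not a real class): the dataclass returned with return_dict=True must flatten to the same positional values the model returns with return_dict=False.

# Hypothetical miniature of the invariant under test; not part of the diff.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DummyOutput:
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None

    def to_tuple(self):
        return tuple(v for v in (self.loss, self.logits) if v is not None)


logits = torch.randn(2, 3)
# Without labels there is no loss, so the tuple starts at logits, exactly
# like a model called with return_dict=False and no labels.
assert torch.allclose(DummyOutput(logits=logits).to_tuple()[0], logits, atol=1e-5)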
@@ -37,6 +37,8 @@ class T5ModelTester:
         self.batch_size = 13
         self.encoder_seq_length = 7
         self.decoder_seq_length = 9
+        # For common tests
+        self.seq_length = self.decoder_seq_length
         self.is_training = True
         self.use_attention_mask = True
         self.use_labels = True
......
@@ -21,6 +21,7 @@ import random
 import tempfile
 import unittest
 from importlib import import_module
+from typing import List, Tuple

 from transformers import is_tf_available, is_torch_available
 from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
@@ -78,6 +79,8 @@ class TFModelTesterMixin:
     is_encoder_decoder = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
         if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
             inputs_dict = {
                 k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
@@ -88,20 +91,21 @@

         if return_labels:
             if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
+                inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
             elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
-                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
-                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
+                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
             elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
-            elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
+                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+            elif model_class in [
+                *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
+                *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
+                *TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
+                *TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+            ]:
+                inputs_dict["labels"] = tf.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
+                )
         return inputs_dict

     def test_initialization(self):
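
The dtype changes above matter because the TF loss mixins compute sparse categorical cross-entropy, which takes integer class indices rather than floats. A hedged standalone illustration (the loss call here is plain Keras, not the transformers helper itself):

# Standalone illustration of integer labels with a sparse categorical loss;
# not part of the diff.
import tensorflow as tf

batch_size, num_labels = 13, 5
logits = tf.random.normal((batch_size, num_labels))
labels = tf.zeros(batch_size, dtype=tf.int32)  # class indices, not one-hot

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(float(loss_fn(labels, logits)))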
@@ -517,6 +521,61 @@ class TFModelTesterMixin:
             max_diff = np.amax(np.abs(out_1 - out_2))
             self.assertLessEqual(max_diff, 1e-5)

+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
+            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+            def recursive_check(tuple_object, dict_object):
+                if isinstance(tuple_object, (List, Tuple)):
+                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif tuple_object is None:
+                    return
+                else:
+                    self.assertTrue(
+                        all(tf.equal(tuple_object, dict_object)),
+                        msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+                    )
+
+            recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(
+                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+            )
+
     def _get_embeds(self, wte, input_ids):
         # ^^ In our TF models, the input_embeddings can take slightly different forms,
         # so we try a few of them.
......
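
As a closing illustration, the nested comparison the TF test performs can be written as a standalone predicate. This is a hedged re-expression of the test's recursive_check, not the exact code above; it uses tf.reduce_all at the leaves, which fully reduces element-wise equality to a single boolean for tensors of any rank:

# Standalone re-expression of the recursive tuple comparison; illustrative only.
import tensorflow as tf


def tensors_match(left, right):
    if isinstance(left, (list, tuple)):
        return all(tensors_match(a, b) for a, b in zip(left, right))
    if left is None:
        return True  # mirrors the test, which skips None entries
    return bool(tf.reduce_all(tf.equal(left, right)))


nested_a = ((tf.ones((2, 2)), None), tf.zeros(3))
nested_b = ((tf.ones((2, 2)), None), tf.zeros(3))
print(tensors_match(nested_a, nested_b))  # True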