Unverified Commit f7cbc13d authored by Lysandre Debut, committed by GitHub

Test model outputs equivalence (#6445)

* Test model outputs equivalence

* Fix failing tests

* From dict to kwargs

* DistilBERT

* Addressing @sgugger and @patrickvonplaten's comments
parent 54c687e9
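
What the new common test guarantees, in user-facing terms: calling a model with return_dict=False and with return_dict=True must produce the same values, with the dict-style output convertible back to the tuple via .to_tuple(). A minimal sketch of that property, assuming any PyTorch checkpoint (bert-base-uncased here is an arbitrary choice, not something the commit itself uses):

    import torch
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    inputs = tokenizer("Hello world", return_tensors="pt")

    with torch.no_grad():
        tuple_output = model(**inputs, return_dict=False)  # plain tuple of tensors
        dict_output = model(**inputs, return_dict=True)    # structured ModelOutput

    # The equivalence this PR adds tests for, across all models and heads:
    assert torch.allclose(tuple_output[0], dict_output.to_tuple()[0], atol=1e-5)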
@@ -21,13 +21,17 @@ import tensorflow as tf
 from .configuration_longformer import LongformerConfig
 from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
-from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput
+from .modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPooling,
+    TFMaskedLMOutput,
+    TFQuestionAnsweringModelOutput,
+)
 from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
 from .modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
-    cast_bool_to_primitive,
     get_initializer,
     keras_serializable,
     shape_list,
@@ -833,33 +837,41 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
             TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
         ]
 
-    def call(self, inputs, training=False):
-        hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs
-
-        all_hidden_states = ()
-        all_attentions = ()
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        padding_len=0,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        training=False,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
         for i, layer_module in enumerate(self.layer):
-            if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
+            if output_hidden_states:
                 hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                 all_hidden_states = all_hidden_states + (hidden_states_to_add,)
 
             layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
             hidden_states = layer_outputs[0]
 
-            if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
+            if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
 
         # Add last layer
-        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
+        if output_hidden_states:
             hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
             all_hidden_states = all_hidden_states + (hidden_states_to_add,)
 
-        outputs = (hidden_states,)
-        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
-            outputs = outputs + (all_hidden_states,)
-        if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
-            outputs = outputs + (all_attentions,)
-        return outputs  # outputs, (hidden states), (attentions)
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
 
 
 @keras_serializable
@@ -992,7 +1004,12 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
 
         encoder_outputs = self.encoder(
-            [embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len],
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            padding_len=padding_len,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
             training=training,
         )
......
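
The Longformer hunks above replace the packed-list call signature and the cast_bool_to_primitive helper with plain keyword arguments, and return a TFBaseModelOutput when return_dict is truthy. A small sketch of how the two return styles line up (the dummy tensor is illustrative only; the None-dropping behavior matches the tuple(v for v in [...] if v is not None) branch in the diff):

    import tensorflow as tf
    from transformers.modeling_tf_outputs import TFBaseModelOutput

    # Dummy (batch, seq_len, hidden) tensor standing in for the encoder output.
    hidden = tf.zeros((1, 8, 4))

    out = TFBaseModelOutput(last_hidden_state=hidden, hidden_states=None, attentions=None)

    print(out.last_hidden_state.shape)  # attribute-style access on the dict output
    print(len(out.to_tuple()))          # -> 1: None fields are dropped from the tuple view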
@@ -18,7 +18,7 @@ import os.path
 import random
 import tempfile
 import unittest
-from typing import List
+from typing import List, Tuple
 
 from transformers import is_torch_available
 from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
@@ -37,6 +37,11 @@ if is_torch_available():
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_MASKED_LM_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         top_k_top_p_filtering,
     )
@@ -63,14 +68,39 @@ class ModelTesterMixin:
     test_chunking = False
     is_encoder_decoder = False
 
-    def _prepare_for_class(self, inputs_dict, model_class):
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
         if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-            return {
+            inputs_dict = {
                 k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                 if isinstance(v, torch.Tensor) and v.ndim > 1
                 else v
                 for k, v in inputs_dict.items()
             }
+
+        if return_labels:
+            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
+            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+                inputs_dict["start_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+                inputs_dict["end_positions"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
+                inputs_dict["labels"] = torch.zeros(
+                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
+                )
+            elif model_class in [
+                *MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
+                *MODEL_FOR_CAUSAL_LM_MAPPING.values(),
+                *MODEL_FOR_MASKED_LM_MAPPING.values(),
+                *MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+            ]:
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+                )
+
         return inputs_dict
 
     def test_save_load(self):
@@ -663,6 +693,64 @@ class ModelTesterMixin:
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            with torch.no_grad():
+                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+            def recursive_check(tuple_object, dict_object):
+                if isinstance(tuple_object, (List, Tuple)):
+                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif tuple_object is None:
+                    return
+                else:
+                    self.assertTrue(
+                        torch.allclose(tuple_object, dict_object, atol=1e-5),
+                        msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}",
+                    )
+
+            recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(
+                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+            )
+
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
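
The nested recursive_check exists because many model outputs are not flat tensors: past_key_values and per-layer hidden states or attentions come back as nested tuples. A self-contained variant of the same comparison, outside the mixin (the function name and sample data are mine, not the PR's):

    from typing import Any
    import torch

    def recursive_allclose(a: Any, b: Any, atol: float = 1e-5) -> bool:
        # Walk nested lists/tuples, skip None leaves, compare tensors element-wise.
        if isinstance(a, (list, tuple)):
            return all(recursive_allclose(x, y, atol) for x, y in zip(a, b))
        if a is None:
            return b is None
        return torch.allclose(a, b, atol=atol)

    # e.g. a past_key_values-like nested structure:
    nested = (torch.ones(2, 3), (torch.zeros(1), None))
    assert recursive_allclose(nested, nested)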
@@ -37,6 +37,8 @@ class T5ModelTester:
         self.batch_size = 13
         self.encoder_seq_length = 7
         self.decoder_seq_length = 9
+        # For common tests
+        self.seq_length = self.decoder_seq_length
         self.is_training = True
         self.use_attention_mask = True
         self.use_labels = True
......
@@ -21,6 +21,7 @@ import random
 import tempfile
 import unittest
 from importlib import import_module
+from typing import List, Tuple
 
 from transformers import is_tf_available, is_torch_available
 from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
@@ -78,6 +79,8 @@ class TFModelTesterMixin:
     is_encoder_decoder = False
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = copy.deepcopy(inputs_dict)
+
         if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
             inputs_dict = {
                 k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
@@ -88,20 +91,21 @@ class TFModelTesterMixin:
         if return_labels:
             if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
+                inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
             elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
-                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
-                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
+                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
             elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
-            elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
-            elif model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values():
-                inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
+                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+            elif model_class in [
+                *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
+                *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
+                *TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
+                *TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+            ]:
+                inputs_dict["labels"] = tf.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
+                )
         return inputs_dict
 
     def test_initialization(self):
@@ -517,6 +521,61 @@ class TFModelTesterMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)
 
+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
+            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+            def recursive_check(tuple_object, dict_object):
+                if isinstance(tuple_object, (List, Tuple)):
+                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                        recursive_check(tuple_iterable_value, dict_iterable_value)
+                elif tuple_object is None:
+                    return
+                else:
+                    self.assertTrue(
+                        all(tf.equal(tuple_object, dict_object)),
+                        msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+                    )
+
+            recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(
+                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+            )
+
     def _get_embeds(self, wte, input_ids):
         # ^^ In our TF models, the input_embeddings can take slightly different forms,
         # so we try a few of them.
......
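
One asymmetry between the two mixins is worth flagging: the PyTorch test compares with torch.allclose(..., atol=1e-5), while the TF test demands exact element-wise equality through tf.equal. If the exact check ever proves flaky, a tolerance-based TF comparison would look like the following (a hypothetical tweak, not part of this commit):

    import tensorflow as tf

    a = tf.constant([1.0, 2.0])
    b = a + 1e-7  # tiny numerical drift

    exact = bool(tf.reduce_all(tf.equal(a, b)))         # False: any bit-level difference fails
    close = bool(tf.reduce_all(tf.abs(a - b) <= 1e-5))  # True: mirrors the PyTorch atol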