Unverified Commit 24184e73 authored by Julien Plu, committed by GitHub

Rework some TF tests (#8492)

* Update some tests

* Small update

* Apply style

* Use max_position_embeddings

* Create a fake attribute

* Create a fake attribute

* Update wrong name

* Wrong TransfoXL model file

* Keep the common tests agnostic
parent f6cdafde
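
The "Use max_position_embeddings" and "Create a fake attribute" bullets refer to the pattern visible in the first hunk below: the common test reads the sequence-length limit from the model tester via `getattr` with a 512 fallback, so a tester whose config names the limit differently can expose a stand-in attribute without the shared test knowing about it. A minimal sketch of that fallback (`DummyModelTester` is hypothetical; the `getattr` call mirrors the diff):

```python
# Hypothetical tester illustrating the fallback used in test_compile_tf_model.
class DummyModelTester:
    # "Fake" attribute mirroring the model config's real position limit,
    # exposed under the name the common test expects.
    max_position_embeddings = 40

max_input = getattr(DummyModelTester(), "max_position_embeddings", 512)
print(max_input)  # 40; a tester without the attribute falls back to 512
```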
@@ -454,7 +454,7 @@ class TFModelTesterMixin:
     def test_compile_tf_model(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        max_input = getattr(self.model_tester, "max_position_embeddings", 512)
         optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
         loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
         metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
@@ -463,14 +463,16 @@ class TFModelTesterMixin:
             if self.is_encoder_decoder:
                 input_ids = {
                     "decoder_input_ids": tf.keras.Input(
-                        batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
+                        batch_shape=(2, max_input),
+                        name="decoder_input_ids",
+                        dtype="int32",
                     ),
-                    "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
+                    "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
                 }
             elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
-                input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
+                input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
             else:
-                input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
+                input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")

             # Prepare our model
             model = model_class(config)
@@ -510,70 +512,64 @@ class TFModelTesterMixin:
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
         decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
         encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
         decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
         encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

-        for model_class in self.all_model_classes:
-            inputs_dict["output_attentions"] = True
-            inputs_dict["use_cache"] = False
-            config.output_hidden_states = False
-            model = model_class(config)
-            outputs = model(self._prepare_for_class(inputs_dict, model_class))
+        def check_decoder_attentions_output(outputs):
+            out_len = len(outputs)
+            self.assertEqual(out_len % 2, 0)
+            decoder_attentions = outputs.decoder_attentions
+            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(decoder_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+            )

+        def check_encoder_attentions_output(outputs):
             attentions = [
                 t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
             ]
-            self.assertEqual(model.config.output_hidden_states, False)
             self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
             )

+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["use_cache"] = False
+            config.output_hidden_states = False
+            model = model_class(config)
+            outputs = model(self._prepare_for_class(inputs_dict, model_class))
             out_len = len(outputs)
+            self.assertEqual(config.output_hidden_states, False)
+            check_encoder_attentions_output(outputs)

             if self.is_encoder_decoder:
-                self.assertEqual(out_len % 2, 0)
-                decoder_attentions = outputs.decoder_attentions
-                self.assertEqual(model.config.output_hidden_states, False)
-                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
-                self.assertListEqual(
-                    list(decoder_attentions[0].shape[-3:]),
-                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
-                )
+                model = model_class(config)
+                outputs = model(self._prepare_for_class(inputs_dict, model_class))
+                self.assertEqual(config.output_hidden_states, False)
+                check_decoder_attentions_output(outputs)

             # Check that output attentions can also be changed via the config
             del inputs_dict["output_attentions"]
             config.output_attentions = True
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
-            attentions = [
-                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
-            ]
-            self.assertEqual(model.config.output_hidden_states, False)
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-            self.assertListEqual(
-                list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-            )
+            self.assertEqual(config.output_hidden_states, False)
+            check_encoder_attentions_output(outputs)

             # Check attention is always last and order is fine
             inputs_dict["output_attentions"] = True
             config.output_hidden_states = True
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
             self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
             self.assertEqual(model.config.output_hidden_states, True)
+            check_encoder_attentions_output(outputs)
-
-            attentions = [
-                t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
-            ]
-            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
-            self.assertListEqual(
-                list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
-            )

     def test_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
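
As a concrete reading of the assertions in the hunk above: with attentions enabled, a model returns one attention tensor per hidden layer, each ending in (num_heads, query_length, key_length). A toy illustration with assumed dimensions, mirroring only the shape contract:

```python
import tensorflow as tf

# Assumed toy dimensions; only the shape contract mirrors the test.
num_hidden_layers, num_attention_heads = 2, 4
batch_size, seq_length, key_length = 3, 7, 7

attentions = tuple(
    tf.zeros((batch_size, num_attention_heads, seq_length, key_length))
    for _ in range(num_hidden_layers)
)
assert len(attentions) == num_hidden_layers
assert list(attentions[0].shape[-3:]) == [num_attention_heads, seq_length, key_length]
```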
@@ -581,10 +577,12 @@ class TFModelTesterMixin:
         def check_hidden_states_output(config, inputs_dict, model_class):
             model = model_class(config)
             outputs = model(self._prepare_for_class(inputs_dict, model_class))
-            hidden_states = [t.numpy() for t in outputs[-1]]
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
             )
+            hidden_states = outputs[-1]
+            self.assertEqual(config.output_attentions, False)
             self.assertEqual(len(hidden_states), expected_num_layers)
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
...
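
The Longformer changes below all follow one move: set `config.return_dict = True` so the model returns a named output object instead of a tuple, then assert on attributes (`result.last_hidden_state`, `result.pooler_output`, `result.logits`) rather than unpacking positionally. A hypothetical stand-in for that output object (the real class lives in transformers; only the field names are mirrored here):

```python
from dataclasses import dataclass
import tensorflow as tf

# Hypothetical stand-in for the named output returned when return_dict=True.
@dataclass
class FakeModelOutputWithPooling:
    last_hidden_state: tf.Tensor  # (batch_size, seq_length, hidden_size)
    pooler_output: tf.Tensor      # (batch_size, hidden_size)

result = FakeModelOutputWithPooling(
    last_hidden_state=tf.zeros((13, 7, 32)),
    pooler_output=tf.zeros((13, 32)),
)
# Attribute access replaces `sequence_output, pooled_output = model(...)`.
assert result.last_hidden_state.shape.as_list() == [13, 7, 32]
assert result.pooler_output.shape.as_list() == [13, 32]
```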
@@ -133,23 +133,21 @@ class TFLongformerModelTester:
     def create_and_check_longformer_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerModel(config=config)
-        sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
-        sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
-        sequence_output, pooled_output = model(input_ids)
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+        result = model(input_ids, token_type_ids=token_type_ids)
+        result = model(input_ids)

-        result = {
-            "sequence_output": sequence_output,
-            "pooled_output": pooled_output,
-        }
         self.parent.assertListEqual(
-            shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
+            shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
         )
-        self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
+        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
     def create_and_check_longformer_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerModel(config=config)
         half_input_mask_length = shape_list(input_mask)[-1] // 2
         global_attention_mask = tf.concat(
@@ -160,59 +158,43 @@ class TFLongformerModelTester:
             axis=-1,
         )

-        sequence_output, pooled_output = model(
+        result = model(
             input_ids,
             attention_mask=input_mask,
             global_attention_mask=global_attention_mask,
             token_type_ids=token_type_ids,
         )
-        sequence_output, pooled_output = model(
-            input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
-        )
-        sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
+        result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
+        result = model(input_ids, global_attention_mask=global_attention_mask)

-        result = {
-            "sequence_output": sequence_output,
-            "pooled_output": pooled_output,
-        }
         self.parent.assertListEqual(
-            shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
+            shape_list(result.last_hidden_state), [self.batch_size, self.seq_length, self.hidden_size]
        )
-        self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
+        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
     def create_and_check_longformer_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerForMaskedLM(config=config)
-        loss, prediction_scores = model(
-            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
-        )
-
-        result = {
-            "loss": loss,
-            "prediction_scores": prediction_scores,
-        }
-        self.parent.assertListEqual(
-            shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size]
-        )
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
     def create_and_check_longformer_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
+        config.return_dict = True
         model = TFLongformerForQuestionAnswering(config=config)
-        loss, start_logits, end_logits = model(
+        result = model(
             input_ids,
             attention_mask=input_mask,
             token_type_ids=token_type_ids,
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )

-        result = {
-            "loss": loss,
-            "start_logits": start_logits,
-            "end_logits": end_logits,
-        }
-        self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
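
The assertions throughout these testers compare against `shape_list`, a transformers helper that prefers static dimensions and falls back to dynamic ones inside graph mode. A minimal re-implementation sketch (the real helper lives in `transformers.modeling_tf_utils`; this version assumes a tensor of known rank):

```python
import tensorflow as tf

def shape_list(tensor):
    """Static dims where known, dynamic tf.shape entries otherwise."""
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

start_logits = tf.zeros((13, 7))  # assumed (batch_size, seq_length)
print(shape_list(start_logits))  # [13, 7]
```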