Unverified Commit 0a375f5a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Deal with multiple choice in common tests (#4886)

* Deal with multiple choice in common tests
parent e8db8b84
...@@ -407,6 +407,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -407,6 +407,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
( (
BertModel, BertModel,
BertForMaskedLM, BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction, BertForNextSentencePrediction,
BertForPreTraining, BertForPreTraining,
BertForQuestionAnswering, BertForQuestionAnswering,
......
...@@ -37,6 +37,7 @@ if is_torch_available(): ...@@ -37,6 +37,7 @@ if is_torch_available():
BertModel, BertModel,
BertConfig, BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
top_k_top_p_filtering, top_k_top_p_filtering,
) )
...@@ -62,6 +63,14 @@ class ModelTesterMixin: ...@@ -62,6 +63,14 @@ class ModelTesterMixin:
test_missing_keys = True test_missing_keys = True
is_encoder_decoder = False is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
    """Adapt the common test inputs for a specific model class.

    Multiple-choice models expect each input tensor with an extra
    ``num_choices`` axis — ``(batch, num_choices, seq_len)`` instead of
    ``(batch, seq_len)`` — so for classes found in
    ``MODEL_FOR_MULTIPLE_CHOICE_MAPPING`` every tensor input is repeated
    along a new second dimension. All other model classes receive
    ``inputs_dict`` unchanged.

    Args:
        inputs_dict (dict): common inputs produced by
            ``prepare_config_and_inputs_for_common``; values are assumed
            to be 2-D ``(batch, seq_len)`` tensors — TODO confirm against
            each model tester.
        model_class: the model class under test.

    Returns:
        dict: inputs suitable for instantiating ``model_class``.
    """
    if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
        # Insert the choice axis and broadcast each tensor along it;
        # .contiguous() is required because .expand() returns a
        # non-contiguous view. Non-tensor entries are passed through
        # untouched instead of crashing on .unsqueeze().
        return {
            k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
            if isinstance(v, torch.Tensor)
            else v
            for k, v in inputs_dict.items()
        }
    return inputs_dict
def test_save_load(self): def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -70,7 +79,7 @@ class ModelTesterMixin: ...@@ -70,7 +79,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
out_2 = outputs[0].cpu().numpy() out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0 out_2[np.isnan(out_2)] = 0
...@@ -79,7 +88,7 @@ class ModelTesterMixin: ...@@ -79,7 +88,7 @@ class ModelTesterMixin:
model = model_class.from_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname)
model.to(torch_device) model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
after_outputs = model(**inputs_dict) after_outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# Make sure we don't have nans # Make sure we don't have nans
out_1 = after_outputs[0].cpu().numpy() out_1 = after_outputs[0].cpu().numpy()
...@@ -109,8 +118,8 @@ class ModelTesterMixin: ...@@ -109,8 +118,8 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
first = model(**inputs_dict)[0] first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**inputs_dict)[0] second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
out_1 = first.cpu().numpy() out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy() out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)] out_1 = out_1[~np.isnan(out_1)]
...@@ -136,7 +145,7 @@ class ModelTesterMixin: ...@@ -136,7 +145,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
...@@ -189,7 +198,7 @@ class ModelTesterMixin: ...@@ -189,7 +198,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs)) self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
...@@ -232,7 +241,7 @@ class ModelTesterMixin: ...@@ -232,7 +241,7 @@ class ModelTesterMixin:
model = model_class(config=configs_no_init) model = model_class(config=configs_no_init)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
inputs = inputs_dict["input_ids"] # Let's keep only input_ids inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"] # Let's keep only input_ids
try: try:
traced_gpt2 = torch.jit.trace(model, inputs) traced_gpt2 = torch.jit.trace(model, inputs)
...@@ -295,7 +304,7 @@ class ModelTesterMixin: ...@@ -295,7 +304,7 @@ class ModelTesterMixin:
head_mask[0, 0] = 0 head_mask[0, 0] = 0
head_mask[-1, :-1] = 0 head_mask[-1, :-1] = 0
head_mask.requires_grad_(requires_grad=True) head_mask.requires_grad_(requires_grad=True)
inputs = inputs_dict.copy() inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask inputs["head_mask"] = head_mask
outputs = model(**inputs) outputs = model(**inputs)
...@@ -346,7 +355,7 @@ class ModelTesterMixin: ...@@ -346,7 +355,7 @@ class ModelTesterMixin:
} }
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
...@@ -381,7 +390,7 @@ class ModelTesterMixin: ...@@ -381,7 +390,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1) self.assertEqual(attentions[0].shape[-3], 1)
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
...@@ -411,7 +420,7 @@ class ModelTesterMixin: ...@@ -411,7 +420,7 @@ class ModelTesterMixin:
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], 1) self.assertEqual(attentions[0].shape[-3], 1)
...@@ -439,7 +448,7 @@ class ModelTesterMixin: ...@@ -439,7 +448,7 @@ class ModelTesterMixin:
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
...@@ -453,7 +462,7 @@ class ModelTesterMixin: ...@@ -453,7 +462,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
...@@ -465,7 +474,7 @@ class ModelTesterMixin: ...@@ -465,7 +474,7 @@ class ModelTesterMixin:
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
...@@ -484,7 +493,7 @@ class ModelTesterMixin: ...@@ -484,7 +493,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs[-1] hidden_states = outputs[-1]
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
...@@ -524,7 +533,7 @@ class ModelTesterMixin: ...@@ -524,7 +533,7 @@ class ModelTesterMixin:
# Check that it actually resizes the embeddings matrix # Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized) # Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**inputs_dict) model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15) model_embed = model.resize_token_embeddings(model_vocab_size - 15)
...@@ -535,7 +544,7 @@ class ModelTesterMixin: ...@@ -535,7 +544,7 @@ class ModelTesterMixin:
# Check that the model can still do a forward pass successfully (every parameter should be resized) # Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary # Input ids should be clamped to the maximum size of the vocabulary
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**inputs_dict) model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix. # Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True models_equal = True
...@@ -638,6 +647,8 @@ class ModelTesterMixin: ...@@ -638,6 +647,8 @@ class ModelTesterMixin:
inputs_dict.pop("decoder_input_ids", None) inputs_dict.pop("decoder_input_ids", None)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
continue
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment