Unverified commit d951c14a, authored by Sylvain Gugger, committed by GitHub
Browse files

Model output test (#6155)

* Use return_dict=True in all tests

* Formatting
parent 86caab1e
...@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
head_mask=head_mask, head_mask=head_mask,
return_dict=False,
**kwargs_encoder, **kwargs_encoder,
) )
...@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask, head_mask=decoder_head_mask,
labels=labels, labels=labels,
return_dict=False,
**kwargs_decoder, **kwargs_decoder,
) )
......
...@@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
lm_loss = None lm_loss, mc_loss = None, None
if mc_labels is not None: if mc_labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
mc_loss = None
if labels is not None: if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict: if not return_dict:
output = (lm_logits, mc_logits) + transformer_outputs[1:] output = (lm_logits, mc_logits) + transformer_outputs[1:]
......
...@@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): ...@@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.reformer( outputs = self.reformer(
input_ids, input_ids,
......
...@@ -121,6 +121,7 @@ class XxxModelTester: ...@@ -121,6 +121,7 @@ class XxxModelTester:
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -134,18 +135,13 @@ class XxxModelTester: ...@@ -134,18 +135,13 @@ class XxxModelTester:
model = XxxModel(config=config) model = XxxModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm( def create_and_check_xxx_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -153,16 +149,10 @@ class XxxModelTester: ...@@ -153,16 +149,10 @@ class XxxModelTester:
model = XxxForMaskedLM(config=config) model = XxxForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
) )
result = { self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_xxx_for_question_answering( def create_and_check_xxx_for_question_answering(
...@@ -171,18 +161,13 @@ class XxxModelTester: ...@@ -171,18 +161,13 @@ class XxxModelTester:
model = XxxForQuestionAnswering(config=config) model = XxxForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -194,13 +179,7 @@ class XxxModelTester: ...@@ -194,13 +179,7 @@ class XxxModelTester:
model = XxxForSequenceClassification(config) model = XxxForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -211,11 +190,7 @@ class XxxModelTester: ...@@ -211,11 +190,7 @@ class XxxModelTester:
model = XxxForTokenClassification(config=config) model = XxxForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -98,6 +98,7 @@ class AlbertModelTester: ...@@ -98,6 +98,7 @@ class AlbertModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
num_hidden_groups=self.num_hidden_groups, num_hidden_groups=self.num_hidden_groups,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -111,18 +112,13 @@ class AlbertModelTester: ...@@ -111,18 +112,13 @@ class AlbertModelTester:
model = AlbertModel(config=config) model = AlbertModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_albert_for_pretraining( def create_and_check_albert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -130,22 +126,17 @@ class AlbertModelTester: ...@@ -130,22 +126,17 @@ class AlbertModelTester:
model = AlbertForPreTraining(config=config) model = AlbertForPreTraining(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores, sop_scores = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
labels=token_labels, labels=token_labels,
sentence_order_label=sequence_labels, sentence_order_label=sequence_labels,
) )
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"sop_scores": sop_scores,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
) )
self.parent.assertListEqual(list(result["sop_scores"].size()), [self.batch_size, config.num_labels]) self.parent.assertListEqual(list(result["sop_logits"].size()), [self.batch_size, config.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_albert_for_masked_lm( def create_and_check_albert_for_masked_lm(
...@@ -154,16 +145,8 @@ class AlbertModelTester: ...@@ -154,16 +145,8 @@ class AlbertModelTester:
model = AlbertForMaskedLM(config=config) model = AlbertForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_albert_for_question_answering( def create_and_check_albert_for_question_answering(
...@@ -172,18 +155,13 @@ class AlbertModelTester: ...@@ -172,18 +155,13 @@ class AlbertModelTester:
model = AlbertForQuestionAnswering(config=config) model = AlbertForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -195,13 +173,7 @@ class AlbertModelTester: ...@@ -195,13 +173,7 @@ class AlbertModelTester:
model = AlbertForSequenceClassification(config) model = AlbertForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -212,11 +184,7 @@ class AlbertModelTester: ...@@ -212,11 +184,7 @@ class AlbertModelTester:
model = AlbertForTokenClassification(config=config) model = AlbertForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -230,16 +198,12 @@ class AlbertModelTester: ...@@ -230,16 +198,12 @@ class AlbertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
......
...@@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase): ...@@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase):
eos_token_id=2, eos_token_id=2,
pad_token_id=1, pad_token_id=1,
bos_token_id=0, bos_token_id=0,
return_dict=True,
) )
return config, input_ids, batch_size return config, input_ids, batch_size
...@@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase): ...@@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase):
model = BartForSequenceClassification(config) model = BartForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels) outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
logits = outputs[1]
expected_shape = torch.Size((batch_size, config.num_labels)) expected_shape = torch.Size((batch_size, config.num_labels))
self.assertEqual(logits.shape, expected_shape) self.assertEqual(outputs["logits"].shape, expected_shape)
loss = outputs[0] self.assertIsInstance(outputs["loss"].item(), float)
self.assertIsInstance(loss.item(), float)
def test_question_answering_forward(self): def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data() config, input_ids, batch_size = self._get_config_and_data()
sequence_labels = ids_tensor([batch_size], 2).to(torch_device) sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
model = BartForQuestionAnswering(config) model = BartForQuestionAnswering(config)
model.to(torch_device) model.to(torch_device)
loss, start_logits, end_logits, _ = model( outputs = model(input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,)
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
)
self.assertEqual(start_logits.shape, input_ids.shape) self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
self.assertEqual(end_logits.shape, input_ids.shape) self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
self.assertIsInstance(loss.item(), float) self.assertIsInstance(outputs["loss"].item(), float)
@timeout_decorator.timeout(1) @timeout_decorator.timeout(1)
def test_lm_forward(self): def test_lm_forward(self):
...@@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase): ...@@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase):
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device) lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config) lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device) lm_model.to(torch_device)
loss, logits, enc_features = lm_model(input_ids=input_ids, labels=lm_labels) outputs = lm_model(input_ids=input_ids, labels=lm_labels)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(logits.shape, expected_shape) self.assertEqual(outputs["logits"].shape, expected_shape)
self.assertIsInstance(loss.item(), float) self.assertIsInstance(outputs["loss"].item(), float)
def test_lm_uneven_forward(self): def test_lm_uneven_forward(self):
config = BartConfig( config = BartConfig(
...@@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase): ...@@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase):
encoder_ffn_dim=8, encoder_ffn_dim=8,
decoder_ffn_dim=8, decoder_ffn_dim=8,
max_position_embeddings=48, max_position_embeddings=48,
return_dict=True,
) )
lm_model = BartForConditionalGeneration(config).to(torch_device) lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size) expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape) self.assertEqual(outputs["logits"].shape, expected_shape)
def test_generate_beam_search(self): def test_generate_beam_search(self):
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device) input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
......
...@@ -120,6 +120,7 @@ class BertModelTester: ...@@ -120,6 +120,7 @@ class BertModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -160,18 +161,13 @@ class BertModelTester: ...@@ -160,18 +161,13 @@ class BertModelTester:
model = BertModel(config=config) model = BertModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_model_as_decoder( def create_and_check_bert_model_as_decoder(
self, self,
...@@ -188,29 +184,24 @@ class BertModelTester: ...@@ -188,29 +184,24 @@ class BertModelTester:
model = BertModel(config) model = BertModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
) )
sequence_output, pooled_output = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
) )
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_causal_lm( def create_and_check_bert_for_causal_lm(
self, self,
...@@ -227,16 +218,8 @@ class BertModelTester: ...@@ -227,16 +218,8 @@ class BertModelTester:
model = BertLMHeadModel(config=config) model = BertLMHeadModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_for_masked_lm( def create_and_check_bert_for_masked_lm(
...@@ -245,16 +228,8 @@ class BertModelTester: ...@@ -245,16 +228,8 @@ class BertModelTester:
model = BertForMaskedLM(config=config) model = BertForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_model_for_causal_lm_as_decoder( def create_and_check_bert_model_for_causal_lm_as_decoder(
...@@ -272,7 +247,7 @@ class BertModelTester: ...@@ -272,7 +247,7 @@ class BertModelTester:
model = BertLMHeadModel(config=config) model = BertLMHeadModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
...@@ -280,20 +255,14 @@ class BertModelTester: ...@@ -280,20 +255,14 @@ class BertModelTester:
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
) )
loss, prediction_scores = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
labels=token_labels, labels=token_labels,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
) )
result = { self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_for_next_sequence_prediction( def create_and_check_bert_for_next_sequence_prediction(
...@@ -302,14 +271,10 @@ class BertModelTester: ...@@ -302,14 +271,10 @@ class BertModelTester:
model = BertForNextSentencePrediction(config=config) model = BertForNextSentencePrediction(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, seq_relationship_score = model( result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels, input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
) )
result = { self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_for_pretraining( def create_and_check_bert_for_pretraining(
...@@ -318,22 +283,17 @@ class BertModelTester: ...@@ -318,22 +283,17 @@ class BertModelTester:
model = BertForPreTraining(config=config) model = BertForPreTraining(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores, seq_relationship_score = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
labels=token_labels, labels=token_labels,
next_sentence_label=sequence_labels, next_sentence_label=sequence_labels,
) )
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
) )
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2]) self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert_for_question_answering( def create_and_check_bert_for_question_answering(
...@@ -342,18 +302,13 @@ class BertModelTester: ...@@ -342,18 +302,13 @@ class BertModelTester:
model = BertForQuestionAnswering(config=config) model = BertForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -365,13 +320,7 @@ class BertModelTester: ...@@ -365,13 +320,7 @@ class BertModelTester:
model = BertForSequenceClassification(config) model = BertForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -382,11 +331,7 @@ class BertModelTester: ...@@ -382,11 +331,7 @@ class BertModelTester:
model = BertForTokenClassification(config=config) model = BertForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -400,16 +345,12 @@ class BertModelTester: ...@@ -400,16 +345,12 @@ class BertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -28,13 +28,13 @@ if is_torch_available(): ...@@ -28,13 +28,13 @@ if is_torch_available():
class CamembertModelIntegrationTest(unittest.TestCase): class CamembertModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_output_embeds_base_model(self): def test_output_embeds_base_model(self):
model = CamembertModel.from_pretrained("camembert-base") model = CamembertModel.from_pretrained("camembert-base", return_dict=True)
model.to(torch_device) model.to(torch_device)
input_ids = torch.tensor( input_ids = torch.tensor(
[[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], device=torch_device, dtype=torch.long, [[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], device=torch_device, dtype=torch.long,
) # J'aime le camembert ! ) # J'aime le camembert !
output = model(input_ids)[0] output = model(input_ids)["last_hidden_state"]
expected_shape = torch.Size((1, 10, 768)) expected_shape = torch.Size((1, 10, 768))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice. # compare the actual values for a slice.
......
...@@ -74,7 +74,6 @@ class ModelTesterMixin: ...@@ -74,7 +74,6 @@ class ModelTesterMixin:
def test_save_load(self): def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
......
...@@ -88,9 +88,10 @@ class CTRLModelTester: ...@@ -88,9 +88,10 @@ class CTRLModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob, # hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob, # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings, n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range # initializer_range=self.initializer_range,
return_dict=True,
) )
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
...@@ -117,29 +118,20 @@ class CTRLModelTester: ...@@ -117,29 +118,20 @@ class CTRLModelTester:
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids) model(input_ids, token_type_ids=token_type_ids)
sequence_output, presents = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertEqual(len(result["presents"]), config.n_layer) self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLLMHeadModel(config) model = CTRLLMHeadModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
......
...@@ -110,6 +110,7 @@ if is_torch_available(): ...@@ -110,6 +110,7 @@ if is_torch_available():
attention_dropout=self.attention_probs_dropout_prob, attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -123,14 +124,10 @@ if is_torch_available(): ...@@ -123,14 +124,10 @@ if is_torch_available():
model = DistilBertModel(config=config) model = DistilBertModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(sequence_output,) = model(input_ids, input_mask) result = model(input_ids, input_mask)
(sequence_output,) = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
def create_and_check_distilbert_for_masked_lm( def create_and_check_distilbert_for_masked_lm(
...@@ -139,13 +136,9 @@ if is_torch_available(): ...@@ -139,13 +136,9 @@ if is_torch_available():
model = DistilBertForMaskedLM(config=config) model = DistilBertForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=token_labels) result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
) )
self.check_loss_output(result) self.check_loss_output(result)
...@@ -155,14 +148,9 @@ if is_torch_available(): ...@@ -155,14 +148,9 @@ if is_torch_available():
model = DistilBertForQuestionAnswering(config=config) model = DistilBertForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -174,11 +162,7 @@ if is_torch_available(): ...@@ -174,11 +162,7 @@ if is_torch_available():
model = DistilBertForSequenceClassification(config) model = DistilBertForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels) result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -190,11 +174,7 @@ if is_torch_available(): ...@@ -190,11 +174,7 @@ if is_torch_available():
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels] list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
) )
...@@ -209,13 +189,9 @@ if is_torch_available(): ...@@ -209,13 +189,9 @@ if is_torch_available():
model.eval() model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels, multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -115,6 +115,7 @@ class DPRModelTester: ...@@ -115,6 +115,7 @@ class DPRModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict()) config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
...@@ -126,15 +127,11 @@ class DPRModelTester: ...@@ -126,15 +127,11 @@ class DPRModelTester:
model = DPRContextEncoder(config=config) model = DPRContextEncoder(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0] result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
embeddings = model(input_ids, token_type_ids=token_type_ids)[0] result = model(input_ids, token_type_ids=token_type_ids)
embeddings = model(input_ids)[0] result = model(input_ids)
result = {
"embeddings": embeddings,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size] list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
) )
def create_and_check_dpr_question_encoder( def create_and_check_dpr_question_encoder(
...@@ -143,15 +140,11 @@ class DPRModelTester: ...@@ -143,15 +140,11 @@ class DPRModelTester:
model = DPRQuestionEncoder(config=config) model = DPRQuestionEncoder(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0] result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
embeddings = model(input_ids, token_type_ids=token_type_ids)[0] result = model(input_ids, token_type_ids=token_type_ids)
embeddings = model(input_ids)[0] result = model(input_ids)
result = {
"embeddings": embeddings,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size] list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
) )
def create_and_check_dpr_reader( def create_and_check_dpr_reader(
...@@ -160,12 +153,7 @@ class DPRModelTester: ...@@ -160,12 +153,7 @@ class DPRModelTester:
model = DPRReader(config=config) model = DPRReader(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,) result = model(input_ids, attention_mask=input_mask,)
result = {
"relevance_logits": relevance_logits,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size]) self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
......
...@@ -97,6 +97,7 @@ class ElectraModelTester: ...@@ -97,6 +97,7 @@ class ElectraModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return ( return (
...@@ -127,15 +128,11 @@ class ElectraModelTester: ...@@ -127,15 +128,11 @@ class ElectraModelTester:
model = ElectraModel(config=config) model = ElectraModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
def create_and_check_electra_for_masked_lm( def create_and_check_electra_for_masked_lm(
...@@ -152,16 +149,8 @@ class ElectraModelTester: ...@@ -152,16 +149,8 @@ class ElectraModelTester:
model = ElectraForMaskedLM(config=config) model = ElectraForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_electra_for_token_classification( def create_and_check_electra_for_token_classification(
...@@ -179,11 +168,7 @@ class ElectraModelTester: ...@@ -179,11 +168,7 @@ class ElectraModelTester:
model = ElectraForTokenClassification(config=config) model = ElectraForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -202,13 +187,7 @@ class ElectraModelTester: ...@@ -202,13 +187,7 @@ class ElectraModelTester:
model = ElectraForPreTraining(config=config) model = ElectraForPreTraining(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -227,13 +206,7 @@ class ElectraModelTester: ...@@ -227,13 +206,7 @@ class ElectraModelTester:
model = ElectraForSequenceClassification(config) model = ElectraForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -251,18 +224,13 @@ class ElectraModelTester: ...@@ -251,18 +224,13 @@ class ElectraModelTester:
model = ElectraForQuestionAnswering(config=config) model = ElectraForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -285,16 +253,12 @@ class ElectraModelTester: ...@@ -285,16 +253,12 @@ class ElectraModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -110,6 +110,7 @@ class FlaubertModelTester(object): ...@@ -110,6 +110,7 @@ class FlaubertModelTester(object):
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
summary_type=self.summary_type, summary_type=self.summary_type,
use_proj=self.use_proj, use_proj=self.use_proj,
return_dict=True,
) )
return ( return (
...@@ -142,15 +143,11 @@ class FlaubertModelTester(object): ...@@ -142,15 +143,11 @@ class FlaubertModelTester(object):
model = FlaubertModel(config=config) model = FlaubertModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids) result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids) result = model(input_ids, langs=token_type_ids)
outputs = model(input_ids) result = model(input_ids)
sequence_output = outputs[0]
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
def create_and_check_flaubert_lm_head( def create_and_check_flaubert_lm_head(
...@@ -169,13 +166,7 @@ class FlaubertModelTester(object): ...@@ -169,13 +166,7 @@ class FlaubertModelTester(object):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
...@@ -195,16 +186,9 @@ class FlaubertModelTester(object): ...@@ -195,16 +186,9 @@ class FlaubertModelTester(object):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) result = model(input_ids)
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
loss, start_logits, end_logits = outputs
result = { result = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -225,10 +209,9 @@ class FlaubertModelTester(object): ...@@ -225,10 +209,9 @@ class FlaubertModelTester(object):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) result = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model( result_with_labels = model(
input_ids, input_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -237,7 +220,7 @@ class FlaubertModelTester(object): ...@@ -237,7 +220,7 @@ class FlaubertModelTester(object):
p_mask=input_mask, p_mask=input_mask,
) )
outputs = model( result_with_labels = model(
input_ids, input_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -245,22 +228,13 @@ class FlaubertModelTester(object): ...@@ -245,22 +228,13 @@ class FlaubertModelTester(object):
is_impossible=is_impossible_labels, is_impossible=is_impossible_labels,
) )
(total_loss,) = outputs (total_loss,) = result_with_labels.to_tuple()
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
(total_loss,) = outputs (total_loss,) = result_with_labels.to_tuple()
result = { self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
}
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top] list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
) )
...@@ -292,13 +266,8 @@ class FlaubertModelTester(object): ...@@ -292,13 +266,8 @@ class FlaubertModelTester(object):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(logits,) = model(input_ids) result = model(input_ids)
loss, logits = model(input_ids, labels=sequence_labels) result = model(input_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
...@@ -320,11 +289,7 @@ class FlaubertModelTester(object): ...@@ -320,11 +289,7 @@ class FlaubertModelTester(object):
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -347,16 +312,12 @@ class FlaubertModelTester(object): ...@@ -347,16 +312,12 @@ class FlaubertModelTester(object):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -122,9 +122,10 @@ class GPT2ModelTester: ...@@ -122,9 +122,10 @@ class GPT2ModelTester:
n_positions=self.max_position_embeddings, n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings, n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range # initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
return_dict=True,
) )
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
...@@ -149,18 +150,14 @@ class GPT2ModelTester: ...@@ -149,18 +150,14 @@ class GPT2ModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, presents = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size], list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
self.parent.assertEqual(len(result["presents"]), config.n_layer) self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = GPT2Model(config=config) model = GPT2Model(config=config)
...@@ -175,7 +172,7 @@ class GPT2ModelTester: ...@@ -175,7 +172,7 @@ class GPT2ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past = outputs output, past = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -185,8 +182,8 @@ class GPT2ModelTester: ...@@ -185,8 +182,8 @@ class GPT2ModelTester:
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
output_from_no_past, _ = model(next_input_ids, token_type_ids=next_token_type_ids) output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
output_from_past, _ = model(next_tokens, token_type_ids=next_token_types, past=past) output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
# select random slice # select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
...@@ -209,7 +206,7 @@ class GPT2ModelTester: ...@@ -209,7 +206,7 @@ class GPT2ModelTester:
attn_mask[:, half_seq_length:] = 0 attn_mask[:, half_seq_length:] = 0
# first forward pass # first forward pass
output, past = model(input_ids, attention_mask=attn_mask) output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -226,8 +223,8 @@ class GPT2ModelTester: ...@@ -226,8 +223,8 @@ class GPT2ModelTester:
) )
# get two different outputs # get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask) output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past, _ = model(next_tokens, past=past, attention_mask=attn_mask) output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"]
# select random slice # select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
...@@ -242,13 +239,10 @@ class GPT2ModelTester: ...@@ -242,13 +239,10 @@ class GPT2ModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
def create_and_check_double_lm_head_model( def create_and_check_double_lm_head_model(
...@@ -270,11 +264,8 @@ class GPT2ModelTester: ...@@ -270,11 +264,8 @@ class GPT2ModelTester:
"labels": multiple_choice_inputs_ids, "labels": multiple_choice_inputs_ids,
} }
loss, lm_logits, mc_logits, _ = model(**inputs) result = model(**inputs)
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
result = {"loss": loss, "lm_logits": lm_logits, "mc_logits": mc_logits}
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size], list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
) )
......
...@@ -108,6 +108,7 @@ class LongformerModelTester: ...@@ -108,6 +108,7 @@ class LongformerModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
attention_window=self.attention_window, attention_window=self.attention_window,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -123,8 +124,8 @@ class LongformerModelTester: ...@@ -123,8 +124,8 @@ class LongformerModelTester:
model.eval() model.eval()
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0] output_with_mask = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
output_without_mask = model(input_ids)[0] output_without_mask = model(input_ids)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4)) self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model( def create_and_check_longformer_model(
...@@ -133,18 +134,13 @@ class LongformerModelTester: ...@@ -133,18 +134,13 @@ class LongformerModelTester:
model = LongformerModel(config=config) model = LongformerModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask( def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -156,25 +152,19 @@ class LongformerModelTester: ...@@ -156,25 +152,19 @@ class LongformerModelTester:
global_attention_mask[:, input_mask.shape[-1] // 2] = 0 global_attention_mask[:, input_mask.shape[-1] // 2] = 0
global_attention_mask = global_attention_mask.to(torch_device) global_attention_mask = global_attention_mask.to(torch_device)
sequence_output, pooled_output = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
) )
sequence_output, pooled_output = model( result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask result = model(input_ids, global_attention_mask=global_attention_mask)
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm( def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -182,16 +172,8 @@ class LongformerModelTester: ...@@ -182,16 +172,8 @@ class LongformerModelTester:
model = LongformerForMaskedLM(config=config) model = LongformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_longformer_for_question_answering( def create_and_check_longformer_for_question_answering(
...@@ -200,7 +182,7 @@ class LongformerModelTester: ...@@ -200,7 +182,7 @@ class LongformerModelTester:
model = LongformerForQuestionAnswering(config=config) model = LongformerForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
global_attention_mask=input_mask, global_attention_mask=input_mask,
...@@ -208,11 +190,6 @@ class LongformerModelTester: ...@@ -208,11 +190,6 @@ class LongformerModelTester:
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -224,13 +201,7 @@ class LongformerModelTester: ...@@ -224,13 +201,7 @@ class LongformerModelTester:
model = LongformerForSequenceClassification(config) model = LongformerForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -241,11 +212,7 @@ class LongformerModelTester: ...@@ -241,11 +212,7 @@ class LongformerModelTester:
model = LongformerForTokenClassification(config=config) model = LongformerForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -260,17 +227,13 @@ class LongformerModelTester: ...@@ -260,17 +227,13 @@ class LongformerModelTester:
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
global_attention_mask=multiple_choice_input_mask, global_attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest): ...@@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
decoder_ffn_dim=32, decoder_ffn_dim=32,
max_position_embeddings=48, max_position_embeddings=48,
add_final_layer_norm=True, add_final_layer_norm=True,
return_dict=True,
) )
lm_model = BartForConditionalGeneration(config).to(torch_device) lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size) expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape) self.assertEqual(result["logits"].shape, expected_shape)
@require_torch @require_torch
......
...@@ -122,6 +122,7 @@ class MobileBertModelTester: ...@@ -122,6 +122,7 @@ class MobileBertModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -162,18 +163,14 @@ class MobileBertModelTester: ...@@ -162,18 +163,14 @@ class MobileBertModelTester:
model = MobileBertModel(config=config) model = MobileBertModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_mobilebert_model_as_decoder( def create_and_check_mobilebert_model_as_decoder(
self, self,
...@@ -190,29 +187,25 @@ class MobileBertModelTester: ...@@ -190,29 +187,25 @@ class MobileBertModelTester:
model = MobileBertModel(config) model = MobileBertModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
) )
sequence_output, pooled_output = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
) )
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_mobilebert_for_masked_lm( def create_and_check_mobilebert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -220,16 +213,8 @@ class MobileBertModelTester: ...@@ -220,16 +213,8 @@ class MobileBertModelTester:
model = MobileBertForMaskedLM(config=config) model = MobileBertForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_mobilebert_for_next_sequence_prediction( def create_and_check_mobilebert_for_next_sequence_prediction(
...@@ -238,14 +223,10 @@ class MobileBertModelTester: ...@@ -238,14 +223,10 @@ class MobileBertModelTester:
model = MobileBertForNextSentencePrediction(config=config) model = MobileBertForNextSentencePrediction(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, seq_relationship_score = model( result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels, input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
) )
result = { self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_mobilebert_for_pretraining( def create_and_check_mobilebert_for_pretraining(
...@@ -254,22 +235,17 @@ class MobileBertModelTester: ...@@ -254,22 +235,17 @@ class MobileBertModelTester:
model = MobileBertForPreTraining(config=config) model = MobileBertForPreTraining(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores, seq_relationship_score = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
labels=token_labels, labels=token_labels,
next_sentence_label=sequence_labels, next_sentence_label=sequence_labels,
) )
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
) )
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2]) self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_mobilebert_for_question_answering( def create_and_check_mobilebert_for_question_answering(
...@@ -278,18 +254,13 @@ class MobileBertModelTester: ...@@ -278,18 +254,13 @@ class MobileBertModelTester:
model = MobileBertForQuestionAnswering(config=config) model = MobileBertForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -301,13 +272,7 @@ class MobileBertModelTester: ...@@ -301,13 +272,7 @@ class MobileBertModelTester:
model = MobileBertForSequenceClassification(config) model = MobileBertForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -318,11 +283,7 @@ class MobileBertModelTester: ...@@ -318,11 +283,7 @@ class MobileBertModelTester:
model = MobileBertForTokenClassification(config=config) model = MobileBertForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -336,16 +297,12 @@ class MobileBertModelTester: ...@@ -336,16 +297,12 @@ class MobileBertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -85,9 +85,10 @@ class OpenAIGPTModelTester: ...@@ -85,9 +85,10 @@ class OpenAIGPTModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob, # hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob, # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings, n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range # initializer_range=self.initializer_range
return_dict=True,
) )
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
...@@ -110,13 +111,12 @@ class OpenAIGPTModelTester: ...@@ -110,13 +111,12 @@ class OpenAIGPTModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids) result = model(input_ids)
result = {"sequence_output": sequence_output}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size], list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
...@@ -124,13 +124,10 @@ class OpenAIGPTModelTester: ...@@ -124,13 +124,10 @@ class OpenAIGPTModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, lm_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args): def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
...@@ -138,11 +135,8 @@ class OpenAIGPTModelTester: ...@@ -138,11 +135,8 @@ class OpenAIGPTModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
result = {"loss": loss, "lm_logits": lm_logits}
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
......
...@@ -165,6 +165,7 @@ class ReformerModelTester: ...@@ -165,6 +165,7 @@ class ReformerModelTester:
attn_layers=self.attn_layers, attn_layers=self.attn_layers,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
hash_seed=self.hash_seed, hash_seed=self.hash_seed,
return_dict=True,
) )
return ( return (
...@@ -181,15 +182,12 @@ class ReformerModelTester: ...@@ -181,15 +182,12 @@ class ReformerModelTester:
model = ReformerModel(config=config) model = ReformerModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, _ = model(input_ids, attention_mask=input_mask) result = model(input_ids, attention_mask=input_mask)
sequence_output, _ = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
}
# 2 * hidden_size because we use reversible resnet layers # 2 * hidden_size because we use reversible resnet layers
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
) )
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
...@@ -198,7 +196,7 @@ class ReformerModelTester: ...@@ -198,7 +196,7 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
loss.backward() loss.backward()
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
...@@ -207,13 +205,9 @@ class ReformerModelTester: ...@@ -207,13 +205,9 @@ class ReformerModelTester:
model = ReformerModelWithLMHead(config=config) model = ReformerModelWithLMHead(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids) result = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.check_loss_output(result) self.check_loss_output(result)
...@@ -222,13 +216,9 @@ class ReformerModelTester: ...@@ -222,13 +216,9 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) result = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.check_loss_output(result) self.check_loss_output(result)
...@@ -325,7 +315,7 @@ class ReformerModelTester: ...@@ -325,7 +315,7 @@ class ReformerModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0] hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3)) self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
...@@ -408,7 +398,7 @@ class ReformerModelTester: ...@@ -408,7 +398,7 @@ class ReformerModelTester:
model.to(torch_device) model.to(torch_device)
model.half() model.half()
model.eval() model.eval()
output = model(input_ids, attention_mask=input_mask)[0] output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
...@@ -444,21 +434,16 @@ class ReformerModelTester: ...@@ -444,21 +434,16 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0] output_logits = model(input_ids, attention_mask=input_mask)["logits"]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1]) self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
model = ReformerForQuestionAnswering(config=config) model = ReformerForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels, input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -474,11 +459,11 @@ class ReformerModelTester: ...@@ -474,11 +459,11 @@ class ReformerModelTester:
input_ids_second = input_ids[:, -1:] input_ids_second = input_ids[:, -1:]
# return saved cache # return saved cache
_, past_buckets_states = model(input_ids_first, use_cache=True) past_buckets_states = model(input_ids_first, use_cache=True)["past_buckets_states"]
# calculate last output with and without cache # calculate last output with and without cache
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True) outputs_with_cache = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)["logits"]
outputs_without_cache = model(input_ids)[0][:, -1] outputs_without_cache = model(input_ids)["logits"][:, -1]
# select random slice idx # select random slice idx
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item() random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
...@@ -504,11 +489,7 @@ class ReformerModelTester: ...@@ -504,11 +489,7 @@ class ReformerModelTester:
model = ReformerForSequenceClassification(config) model = ReformerForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels) result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment