Unverified Commit 5deed37f authored by Stas Bekman, committed by GitHub

cleanup torch unittests (#6196)

* improve unit tests

this is a sample of one test according to the request in https://github.com/huggingface/transformers/issues/5973
before I apply it to the rest

* batch 1

* batch 2

* batch 3

* batch 4

* batch 5

* style

* non-tf template

* last deletion of check_loss_output
parent b390a567
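
For context, the cleanup applies one pattern throughout the testers below: self.parent.assertListEqual(list(tensor.size()), [...]) becomes self.parent.assertEqual(tensor.shape, (...)), and the per-tester check_loss_output helper is dropped in favour of inline scalar-shape checks. A minimal, self-contained sketch of the before/after behaviour; the test class and variable names here are hypothetical and not taken from the diff:

import unittest

import torch


class AssertionStyleSketch(unittest.TestCase):
    # Hypothetical standalone test; it only illustrates the rewrite and is not
    # part of the transformers test suite.
    def test_old_and_new_shape_checks(self):
        batch_size, seq_length, hidden_size = 2, 5, 8
        last_hidden_state = torch.zeros(batch_size, seq_length, hidden_size)

        # Old style: convert torch.Size to a list and compare element by element.
        self.assertListEqual(list(last_hidden_state.size()), [batch_size, seq_length, hidden_size])

        # New style: torch.Size is a tuple subclass, so it compares directly against a tuple.
        self.assertEqual(last_hidden_state.shape, (batch_size, seq_length, hidden_size))

        # A scalar loss has an empty shape, so the removed check_loss_output
        # helper collapses into a one-line inline assertion.
        loss = torch.tensor(0.0)
        self.assertEqual(loss.shape, ())


if __name__ == "__main__":
    unittest.main()

Both styles pass for the same tensor, since torch.Size compares equal to a plain tuple; the new form simply drops the list(...) conversion and reads the shape attribute directly.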
@@ -126,9 +126,6 @@ class XxxModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_xxx_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -138,10 +135,8 @@ class XxxModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_xxx_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -152,8 +147,7 @@ class XxxModelTester:
         result = model(
             input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_xxx_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -168,9 +162,8 @@ class XxxModelTester:
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_xxx_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -180,8 +173,7 @@ class XxxModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

     def create_and_check_xxx_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -191,8 +183,7 @@ class XxxModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -103,9 +103,6 @@ class AlbertModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_albert_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -115,10 +112,8 @@ class AlbertModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_albert_for_pretraining(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -133,11 +128,8 @@ class AlbertModelTester:
             labels=token_labels,
             sentence_order_label=sequence_labels,
         )
-        self.parent.assertListEqual(
-            list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-        )
-        self.parent.assertListEqual(list(result["sop_logits"].size()), [self.batch_size, config.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, config.num_labels))

     def create_and_check_albert_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -146,8 +138,7 @@ class AlbertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_albert_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -162,9 +153,8 @@ class AlbertModelTester:
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_albert_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -174,8 +164,7 @@ class AlbertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

     def create_and_check_albert_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -185,8 +174,7 @@ class AlbertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def create_and_check_albert_for_multiple_choice(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -204,7 +192,7 @@ class AlbertModelTester:
             token_type_ids=multiple_choice_token_type_ids,
             labels=choice_labels,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -152,9 +152,6 @@ class BertModelTester:
             encoder_attention_mask,
         )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_bert_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -164,10 +161,8 @@ class BertModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_bert_model_as_decoder(
         self,
@@ -198,10 +193,8 @@ class BertModelTester:
             encoder_hidden_states=encoder_hidden_states,
         )
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_bert_for_causal_lm(
         self,
@@ -219,8 +212,7 @@ class BertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_bert_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -229,8 +221,7 @@ class BertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_bert_model_for_causal_lm_as_decoder(
         self,
@@ -262,8 +253,7 @@ class BertModelTester:
             labels=token_labels,
             encoder_hidden_states=encoder_hidden_states,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_bert_for_next_sequence_prediction(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -274,8 +264,7 @@ class BertModelTester:
         result = model(
             input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

     def create_and_check_bert_for_pretraining(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -290,11 +279,8 @@ class BertModelTester:
             labels=token_labels,
             next_sentence_label=sequence_labels,
         )
-        self.parent.assertListEqual(
-            list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-        )
-        self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

     def create_and_check_bert_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -309,9 +295,8 @@ class BertModelTester:
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_bert_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -321,8 +306,7 @@ class BertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

     def create_and_check_bert_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -332,8 +316,7 @@ class BertModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def create_and_check_bert_for_multiple_choice(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -351,8 +334,7 @@ class BertModelTester:
             token_type_ids=multiple_choice_token_type_ids,
             labels=choice_labels,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -108,9 +108,6 @@ class CTRLModelTester:
             choice_labels,
         )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = CTRLModel(config=config)
         model.to(torch_device)
@@ -119,9 +116,7 @@ class CTRLModelTester:
         model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
         model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)

     def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
@@ -130,8 +125,8 @@ class CTRLModelTester:
         model.eval()
         result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -115,9 +115,6 @@ if is_torch_available():
             return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

-        def check_loss_output(self, result):
-            self.parent.assertListEqual(list(result["loss"].size()), [])
-
         def create_and_check_distilbert_model(
             self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
         ):
@@ -126,8 +123,8 @@ if is_torch_available():
             model.eval()
             result = model(input_ids, input_mask)
             result = model(input_ids)
-            self.parent.assertListEqual(
-                list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-            )
+            self.parent.assertEqual(
+                result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
+            )

         def create_and_check_distilbert_for_masked_lm(
@@ -137,10 +134,7 @@ if is_torch_available():
             model.to(torch_device)
             model.eval()
             result = model(input_ids, attention_mask=input_mask, labels=token_labels)
-            self.parent.assertListEqual(
-                list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-            )
-            self.check_loss_output(result)
+            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

         def create_and_check_distilbert_for_question_answering(
             self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -151,9 +145,8 @@ if is_torch_available():
             result = model(
                 input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
             )
-            self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-            self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-            self.check_loss_output(result)
+            self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+            self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

         def create_and_check_distilbert_for_sequence_classification(
             self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -163,8 +156,7 @@ if is_torch_available():
             model.to(torch_device)
             model.eval()
             result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
-            self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-            self.check_loss_output(result)
+            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

         def create_and_check_distilbert_for_token_classification(
             self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -175,10 +167,7 @@ if is_torch_available():
             model.eval()
             result = model(input_ids, attention_mask=input_mask, labels=token_labels)
-            self.parent.assertListEqual(
-                list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
-            )
-            self.check_loss_output(result)
+            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

         def create_and_check_distilbert_for_multiple_choice(
             self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -192,8 +181,7 @@ if is_torch_available():
             result = model(
                 multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
             )
-            self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
-            self.check_loss_output(result)
+            self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
...
@@ -130,9 +130,7 @@ class DPRModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
-        )
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

     def create_and_check_dpr_question_encoder(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -143,9 +141,7 @@ class DPRModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
-        )
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

     def create_and_check_dpr_reader(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -154,9 +150,10 @@ class DPRModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask,)
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
+
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.relevance_logits.shape, (self.batch_size,))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
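One small detail in the one-dimensional cases above (relevance_logits here, cls_logits in the Flaubert tester further down): the expected shape has to be written as a one-element tuple, so the trailing comma in (self.batch_size,) is load-bearing. A quick hypothetical illustration with a stand-in tensor:

import torch

relevance_logits = torch.zeros(4)       # stand-in 1-D tensor, batch_size = 4
assert relevance_logits.shape == (4,)   # one-element tuple matches torch.Size([4])
assert relevance_logits.shape != (4)    # (4) is just the int 4, not a tuple, so this is not equal
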
@@ -111,9 +111,6 @@ class ElectraModelTester:
             fake_token_labels,
         )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_electra_model(
         self,
         config,
@@ -131,9 +128,7 @@ class ElectraModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

     def create_and_check_electra_for_masked_lm(
         self,
@@ -150,8 +145,7 @@ class ElectraModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_electra_for_token_classification(
         self,
@@ -169,8 +163,7 @@ class ElectraModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def create_and_check_electra_for_pretraining(
         self,
@@ -188,8 +181,7 @@ class ElectraModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_electra_for_sequence_classification(
         self,
@@ -207,8 +199,7 @@ class ElectraModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

     def create_and_check_electra_for_question_answering(
         self,
@@ -231,9 +222,8 @@ class ElectraModelTester:
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_electra_for_multiple_choice(
         self,
@@ -259,8 +249,7 @@ class ElectraModelTester:
             token_type_ids=multiple_choice_token_type_ids,
             labels=choice_labels,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -253,9 +253,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)

-    def check_loss_output(self, loss):
-        self.assertEqual(loss.size(), ())
-
     def create_and_check_bert_encoder_decoder_model_labels(
         self,
         config,
@@ -281,7 +278,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
         )

         mlm_loss = outputs_encoder_decoder[0]
-        self.check_loss_output(mlm_loss)

         # check that backprop works
         mlm_loss.backward()
...
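The encoder-decoder test above keeps only the mlm_loss.backward() call as its sanity check once the local check_loss_output helper is gone. A minimal sketch of that pattern using a stand-in single-layer model instead of the real encoder-decoder; all names here are hypothetical:

import unittest

import torch


class LossBackpropSketch(unittest.TestCase):
    def test_scalar_loss_and_backprop(self):
        # Stand-in for the model under test: a single linear layer.
        model = torch.nn.Linear(4, 2)
        logits = model(torch.randn(3, 4))
        loss = torch.nn.functional.cross_entropy(logits, torch.tensor([0, 1, 0]))

        # Inline replacement for the removed helper: a scalar loss has shape ().
        self.assertEqual(loss.shape, ())

        # check that backprop works
        loss.backward()
        self.assertIsNotNone(model.weight.grad)


if __name__ == "__main__":
    unittest.main()
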
@@ -125,9 +125,6 @@ class FlaubertModelTester(object):
             input_mask,
         )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_flaubert_model(
         self,
         config,
@@ -146,9 +143,7 @@ class FlaubertModelTester(object):
         result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
         result = model(input_ids, langs=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

     def create_and_check_flaubert_lm_head(
         self,
@@ -167,8 +162,8 @@ class FlaubertModelTester(object):
         model.eval()
         result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_flaubert_simple_qa(
         self,
@@ -189,9 +184,8 @@ class FlaubertModelTester(object):
         result = model(input_ids)
         result = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_flaubert_qa(
         self,
@@ -234,21 +228,16 @@ class FlaubertModelTester(object):
         (total_loss,) = result_with_labels.to_tuple()
-        self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
-        )
-        self.parent.assertListEqual(
-            list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_log_probs"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
+        self.parent.assertEqual(result_with_labels.loss.shape, ())
+        self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(
+            result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(
+            result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))

     def create_and_check_flaubert_sequence_classif(
         self,
@@ -269,8 +258,8 @@ class FlaubertModelTester(object):
         result = model(input_ids)
         result = model(input_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

     def create_and_check_flaubert_token_classif(
         self,
@@ -290,8 +279,7 @@ class FlaubertModelTester(object):
         model.eval()
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def create_and_check_flaubert_multiple_choice(
         self,
@@ -318,8 +306,7 @@ class FlaubertModelTester(object):
             token_type_ids=multiple_choice_token_type_ids,
             labels=choice_labels,
         )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -142,9 +142,6 @@ class GPT2ModelTester:
             choice_labels,
         )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = GPT2Model(config=config)
         model.to(torch_device)
@@ -154,9 +151,7 @@ class GPT2ModelTester:
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)

     def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
@@ -240,10 +235,8 @@ class GPT2ModelTester:
         model.eval()
         result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
+        self.parent.assertEqual(result.loss.shape, ())
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_double_lm_head_model(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
@@ -265,11 +258,11 @@ class GPT2ModelTester:
         }
         result = model(**inputs)
-        self.parent.assertListEqual(list(result["lm_loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
-        )
-        self.parent.assertListEqual(list(result["mc_logits"].size()), [self.batch_size, self.num_choices])
+        self.parent.assertEqual(result.lm_loss.shape, ())
+        self.parent.assertEqual(
+            result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
+        )
+        self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -113,9 +113,6 @@ class LongformerModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])
-
     def create_and_check_attention_mask_determinism(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -137,10 +134,8 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
         result = model(input_ids, token_type_ids=token_type_ids)
         result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_longformer_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -161,10 +156,8 @@ class LongformerModelTester:
         result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
         result = model(input_ids, global_attention_mask=global_attention_mask)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

     def create_and_check_longformer_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -173,8 +166,7 @@ class LongformerModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_longformer_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -190,9 +182,8 @@ class LongformerModelTester:
             start_positions=sequence_labels,
             end_positions=sequence_labels,
         )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

     def create_and_check_longformer_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -202,8 +193,7 @@ class LongformerModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

     def create_and_check_longformer_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -213,8 +203,7 @@ class LongformerModelTester:
         model.to(torch_device)
         model.eval()
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

     def create_and_check_longformer_for_multiple_choice(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -234,8 +223,7 @@ class LongformerModelTester:
             token_type_ids=multiple_choice_token_type_ids,
             labels=choice_labels,
        )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
...
@@ -154,9 +154,6 @@ class MobileBertModelTester:
            encoder_attention_mask,
        )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def create_and_check_mobilebert_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
@@ -167,10 +164,8 @@ class MobileBertModelTester:
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_mobilebert_model_as_decoder(
        self,
@@ -202,10 +197,8 @@ class MobileBertModelTester:
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_mobilebert_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -214,8 +207,7 @@ class MobileBertModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-        self.check_loss_output(result)

    def create_and_check_mobilebert_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -226,8 +218,7 @@ class MobileBertModelTester:
        result = model(
            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
        )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
-        self.check_loss_output(result)

    def create_and_check_mobilebert_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -242,11 +233,8 @@ class MobileBertModelTester:
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
-        self.parent.assertListEqual(
-            list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
-        )
-        self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_mobilebert_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -261,9 +249,8 @@ class MobileBertModelTester:
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-        self.check_loss_output(result)

    def create_and_check_mobilebert_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -273,8 +260,7 @@ class MobileBertModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
-        self.check_loss_output(result)

    def create_and_check_mobilebert_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -284,8 +270,7 @@ class MobileBertModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-        self.check_loss_output(result)

    def create_and_check_mobilebert_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -303,8 +288,7 @@ class MobileBertModelTester:
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
-        self.check_loss_output(result)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
...
@@ -103,9 +103,6 @@ class OpenAIGPTModelTester:
            choice_labels,
        )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def create_and_check_openai_gpt_model(self, config, input_ids, head_mask, token_type_ids, *args):
        model = OpenAIGPTModel(config=config)
        model.to(torch_device)
@@ -115,9 +112,7 @@ class OpenAIGPTModelTester:
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
        model = OpenAIGPTLMHeadModel(config)
@@ -125,10 +120,8 @@ class OpenAIGPTModelTester:
        model.eval()
        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
+        self.parent.assertEqual(result.loss.shape, ())
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
        model = OpenAIGPTDoubleHeadsModel(config)
@@ -136,10 +129,8 @@ class OpenAIGPTModelTester:
        model.eval()
        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
-        self.parent.assertListEqual(list(result["lm_loss"].size()), [])
+        self.parent.assertEqual(result.lm_loss.shape, ())
-        self.parent.assertListEqual(
-            list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
+        self.parent.assertEqual(result.lm_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
...
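The `result.loss.shape == ()` checks introduced above rely on the loss being a mean-reduced scalar, i.e. a 0-dim tensor. A rough sketch under that assumption (only torch is used; the tensor sizes are arbitrary) showing that the old and new checks verify the same thing:

```python
import torch
import torch.nn.functional as F

batch_size, seq_length, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_length, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_length))

# With the default "mean" reduction, cross_entropy returns a 0-dim tensor.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

assert loss.shape == ()         # new-style check
assert list(loss.size()) == []  # what check_loss_output used to assert
```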
@@ -175,9 +175,6 @@ class ReformerModelTester:
            choice_labels,
        )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)
@@ -186,8 +183,8 @@ class ReformerModelTester:
        result = model(input_ids)
        # 2 * hidden_size because we use reversible resnet layers
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
-        )
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, self.seq_length, 2 * self.hidden_size)
+        )

    def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
@@ -206,10 +203,7 @@ class ReformerModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=input_ids)
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
        config.is_decoder = False
@@ -217,10 +211,7 @@ class ReformerModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=input_ids)
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
-        self.check_loss_output(result)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_reformer_model_with_attn_mask(
        self, config, input_ids, input_mask, choice_labels, is_decoder=False
@@ -444,9 +435,8 @@ class ReformerModelTester:
        result = model(
            input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
        )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-        self.check_loss_output(result)

    def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
        config.is_decoder = True
@@ -490,8 +480,7 @@ class ReformerModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
-        self.check_loss_output(result)

class ReformerTesterMixin:
...
@@ -101,9 +101,6 @@ class RobertaModelTester:
        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def create_and_check_roberta_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
@@ -114,10 +111,8 @@ class RobertaModelTester:
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
-        self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_roberta_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -126,8 +121,7 @@ class RobertaModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-        self.check_loss_output(result)

    def create_and_check_roberta_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -137,8 +131,7 @@ class RobertaModelTester:
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-        self.check_loss_output(result)

    def create_and_check_roberta_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -156,8 +149,7 @@ class RobertaModelTester:
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
-        self.check_loss_output(result)

    def create_and_check_roberta_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@@ -172,9 +164,8 @@ class RobertaModelTester:
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-        self.check_loss_output(result)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
...
@@ -95,9 +95,6 @@ class T5ModelTester:
            lm_labels,
        )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def check_prepare_lm_labels_via_shift_left(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
...
@@ -128,9 +128,6 @@ class XLMModelTester:
            input_mask,
        )

-    def check_loss_output(self, result):
-        self.parent.assertListEqual(list(result["loss"].size()), [])

    def create_and_check_xlm_model(
        self,
        config,
@@ -149,9 +146,7 @@ class XLMModelTester:
        result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
        result = model(input_ids, langs=token_type_ids)
        result = model(input_ids)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_xlm_lm_head(
        self,
@@ -170,8 +165,8 @@ class XLMModelTester:
        model.eval()
        result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
+        self.parent.assertEqual(result.loss.shape, ())
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_xlm_simple_qa(
        self,
@@ -193,9 +188,8 @@ class XLMModelTester:
        outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
        result = outputs
-        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
-        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-        self.check_loss_output(result)

    def create_and_check_xlm_qa(
        self,
@@ -238,21 +232,16 @@ class XLMModelTester:
        (total_loss,) = result_with_labels.to_tuple()
-        self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
-        )
-        self.parent.assertListEqual(
-            list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_log_probs"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
+        self.parent.assertEqual(result_with_labels.loss.shape, ())
+        self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(
+            result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(
+            result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))

    def create_and_check_xlm_sequence_classif(
        self,
@@ -272,8 +261,8 @@ class XLMModelTester:
        result = model(input_ids)
        result = model(input_ids, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
+        self.parent.assertEqual(result.loss.shape, ())
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def create_and_check_xlm_token_classif(
        self,
@@ -293,8 +282,7 @@ class XLMModelTester:
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-        self.check_loss_output(result)

    def create_and_check_xlm_for_multiple_choice(
        self,
@@ -321,8 +309,7 @@ class XLMModelTester:
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
-        self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
-        self.check_loss_output(result)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
...
@@ -190,9 +190,7 @@ class XLNetModelTester:
        base_model_output = model(input_ids_1)
        self.parent.assertEqual(len(base_model_output), 2)
-        self.parent.assertListEqual(
-            list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
-        )
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
@@ -311,19 +309,15 @@ class XLNetModelTester:
        _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
-        self.parent.assertListEqual(list(result1["loss"].size()), [])
+        self.parent.assertEqual(result1.loss.shape, ())
-        self.parent.assertListEqual(
-            list(result1["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
+        self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result1["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
        )
-        self.parent.assertListEqual(list(result2["loss"].size()), [])
+        self.parent.assertEqual(result2.loss.shape, ())
-        self.parent.assertListEqual(
-            list(result2["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
-        )
+        self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result2["mems"]),
            [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
@@ -373,21 +367,16 @@ class XLNetModelTester:
        total_loss, mems = result_with_labels.to_tuple()
-        self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
-        self.parent.assertListEqual(
-            list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
-        )
-        self.parent.assertListEqual(
-            list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_log_probs"].size()),
-            [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(
-            list(result["end_top_index"].size()), [self.batch_size, model.config.start_n_top * model.config.end_n_top],
-        )
-        self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
+        self.parent.assertEqual(result_with_labels.loss.shape, ())
+        self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
+        self.parent.assertEqual(
+            result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(
+            result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
+        )
+        self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
@@ -415,10 +404,8 @@ class XLNetModelTester:
        result = model(input_ids_1)
        result = model(input_ids_1, labels=token_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
+        self.parent.assertEqual(result.loss.shape, ())
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
-        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
@@ -446,10 +433,8 @@ class XLNetModelTester:
        result = model(input_ids_1)
        result = model(input_ids_1, labels=sequence_labels)
-        self.parent.assertListEqual(list(result["loss"].size()), [])
+        self.parent.assertEqual(result.loss.shape, ())
-        self.parent.assertListEqual(
-            list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
-        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
...
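Note that the `mems` assertions are deliberately left on `assertListEqual`: `mems` is a sequence of per-layer tensors, so the test compares a list of shapes rather than a single shape. A rough standalone sketch (the dimension names mirror the tester attributes, but the values below are made up):

```python
import torch

num_hidden_layers, seq_length, batch_size, hidden_size = 4, 7, 2, 16

# One memory tensor per layer, as in the XLNet assertions above.
mems = [torch.zeros(seq_length, batch_size, hidden_size) for _ in range(num_hidden_layers)]

expected = [[seq_length, batch_size, hidden_size]] * num_hidden_layers
assert [list(mem.size()) for mem in mems] == expected
```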