Unverified Commit d951c14a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Model output test (#6155)

* Use return_dict=True in all tests

* Formatting
parent 86caab1e
......@@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
return_dict=False,
**kwargs_encoder,
)
......@@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
labels=labels,
return_dict=False,
**kwargs_decoder,
)
......
......@@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
lm_loss = None
lm_loss, mc_loss = None, None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
mc_loss = None
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits, mc_logits) + transformer_outputs[1:]
......
......@@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.reformer(
input_ids,
......
......@@ -121,6 +121,7 @@ class XxxModelTester:
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -134,18 +135,13 @@ class XxxModelTester:
model = XxxModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -153,16 +149,10 @@ class XxxModelTester:
model = XxxForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_xxx_for_question_answering(
......@@ -171,18 +161,13 @@ class XxxModelTester:
model = XxxForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -194,13 +179,7 @@ class XxxModelTester:
model = XxxForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -211,11 +190,7 @@ class XxxModelTester:
model = XxxForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......
......@@ -98,6 +98,7 @@ class AlbertModelTester:
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
num_hidden_groups=self.num_hidden_groups,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -111,18 +112,13 @@ class AlbertModelTester:
model = AlbertModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_albert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -130,22 +126,17 @@ class AlbertModelTester:
model = AlbertForPreTraining(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores, sop_scores = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=token_labels,
sentence_order_label=sequence_labels,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"sop_scores": sop_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["sop_scores"].size()), [self.batch_size, config.num_labels])
self.parent.assertListEqual(list(result["sop_logits"].size()), [self.batch_size, config.num_labels])
self.check_loss_output(result)
def create_and_check_albert_for_masked_lm(
......@@ -154,16 +145,8 @@ class AlbertModelTester:
model = AlbertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_albert_for_question_answering(
......@@ -172,18 +155,13 @@ class AlbertModelTester:
model = AlbertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -195,13 +173,7 @@ class AlbertModelTester:
model = AlbertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -212,11 +184,7 @@ class AlbertModelTester:
model = AlbertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -230,16 +198,12 @@ class AlbertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
......
......@@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase):
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
return_dict=True,
)
return config, input_ids, batch_size
......@@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase):
model = BartForSequenceClassification(config)
model.to(torch_device)
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
logits = outputs[1]
expected_shape = torch.Size((batch_size, config.num_labels))
self.assertEqual(logits.shape, expected_shape)
loss = outputs[0]
self.assertIsInstance(loss.item(), float)
self.assertEqual(outputs["logits"].shape, expected_shape)
self.assertIsInstance(outputs["loss"].item(), float)
def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
model = BartForQuestionAnswering(config)
model.to(torch_device)
loss, start_logits, end_logits, _ = model(
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
)
outputs = model(input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,)
self.assertEqual(start_logits.shape, input_ids.shape)
self.assertEqual(end_logits.shape, input_ids.shape)
self.assertIsInstance(loss.item(), float)
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
self.assertIsInstance(outputs["loss"].item(), float)
@timeout_decorator.timeout(1)
def test_lm_forward(self):
......@@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase):
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
lm_model = BartForConditionalGeneration(config)
lm_model.to(torch_device)
loss, logits, enc_features = lm_model(input_ids=input_ids, labels=lm_labels)
outputs = lm_model(input_ids=input_ids, labels=lm_labels)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertIsInstance(loss.item(), float)
self.assertEqual(outputs["logits"].shape, expected_shape)
self.assertIsInstance(outputs["loss"].item(), float)
def test_lm_uneven_forward(self):
config = BartConfig(
......@@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase):
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
return_dict=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_generate_beam_search(self):
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
......
......@@ -120,6 +120,7 @@ class BertModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -160,18 +161,13 @@ class BertModelTester:
model = BertModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_model_as_decoder(
self,
......@@ -188,29 +184,24 @@ class BertModelTester:
model = BertModel(config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_bert_for_causal_lm(
self,
......@@ -227,16 +218,8 @@ class BertModelTester:
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_bert_for_masked_lm(
......@@ -245,16 +228,8 @@ class BertModelTester:
model = BertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_bert_model_for_causal_lm_as_decoder(
......@@ -272,7 +247,7 @@ class BertModelTester:
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
......@@ -280,20 +255,14 @@ class BertModelTester:
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
loss, prediction_scores = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=token_labels,
encoder_hidden_states=encoder_hidden_states,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_bert_for_next_sequence_prediction(
......@@ -302,14 +271,10 @@ class BertModelTester:
model = BertForNextSentencePrediction(config=config)
model.to(torch_device)
model.eval()
loss, seq_relationship_score = model(
result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_bert_for_pretraining(
......@@ -318,22 +283,17 @@ class BertModelTester:
model = BertForPreTraining(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores, seq_relationship_score = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=token_labels,
next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_bert_for_question_answering(
......@@ -342,18 +302,13 @@ class BertModelTester:
model = BertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -365,13 +320,7 @@ class BertModelTester:
model = BertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -382,11 +331,7 @@ class BertModelTester:
model = BertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -400,16 +345,12 @@ class BertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -28,13 +28,13 @@ if is_torch_available():
class CamembertModelIntegrationTest(unittest.TestCase):
@slow
def test_output_embeds_base_model(self):
model = CamembertModel.from_pretrained("camembert-base")
model = CamembertModel.from_pretrained("camembert-base", return_dict=True)
model.to(torch_device)
input_ids = torch.tensor(
[[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], device=torch_device, dtype=torch.long,
) # J'aime le camembert !
output = model(input_ids)[0]
output = model(input_ids)["last_hidden_state"]
expected_shape = torch.Size((1, 10, 768))
self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice.
......
......@@ -74,7 +74,6 @@ class ModelTesterMixin:
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
model = model_class(config)
......
......@@ -88,9 +88,10 @@ class CTRLModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
# initializer_range=self.initializer_range,
return_dict=True,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -117,29 +118,20 @@ class CTRLModelTester:
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids)
sequence_output, presents = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(len(result["presents"]), config.n_layer)
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLLMHeadModel(config)
model.to(torch_device)
model.eval()
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
......
......@@ -110,6 +110,7 @@ if is_torch_available():
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -123,14 +124,10 @@ if is_torch_available():
model = DistilBertModel(config=config)
model.to(torch_device)
model.eval()
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
}
result = model(input_ids, input_mask)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
def create_and_check_distilbert_for_masked_lm(
......@@ -139,13 +136,9 @@ if is_torch_available():
model = DistilBertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)
......@@ -155,14 +148,9 @@ if is_torch_available():
model = DistilBertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -174,11 +162,7 @@ if is_torch_available():
model = DistilBertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -190,11 +174,7 @@ if is_torch_available():
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
)
......@@ -209,13 +189,9 @@ if is_torch_available():
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -115,6 +115,7 @@ class DPRModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
......@@ -126,15 +127,11 @@ class DPRModelTester:
model = DPRContextEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_question_encoder(
......@@ -143,15 +140,11 @@ class DPRModelTester:
model = DPRQuestionEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_reader(
......@@ -160,12 +153,7 @@ class DPRModelTester:
model = DPRReader(config=config)
model.to(torch_device)
model.eval()
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
result = {
"relevance_logits": relevance_logits,
"start_logits": start_logits,
"end_logits": end_logits,
}
result = model(input_ids, attention_mask=input_mask,)
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
......
......@@ -97,6 +97,7 @@ class ElectraModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
return (
......@@ -127,15 +128,11 @@ class ElectraModelTester:
model = ElectraModel(config=config)
model.to(torch_device)
model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
def create_and_check_electra_for_masked_lm(
......@@ -152,16 +149,8 @@ class ElectraModelTester:
model = ElectraForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_electra_for_token_classification(
......@@ -179,11 +168,7 @@ class ElectraModelTester:
model = ElectraForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -202,13 +187,7 @@ class ElectraModelTester:
model = ElectraForPreTraining(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -227,13 +206,7 @@ class ElectraModelTester:
model = ElectraForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -251,18 +224,13 @@ class ElectraModelTester:
model = ElectraForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -285,16 +253,12 @@ class ElectraModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -110,6 +110,7 @@ class FlaubertModelTester(object):
initializer_range=self.initializer_range,
summary_type=self.summary_type,
use_proj=self.use_proj,
return_dict=True,
)
return (
......@@ -142,15 +143,11 @@ class FlaubertModelTester(object):
model = FlaubertModel(config=config)
model.to(torch_device)
model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids)
outputs = model(input_ids)
sequence_output = outputs[0]
result = {
"sequence_output": sequence_output,
}
result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
result = model(input_ids, langs=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
def create_and_check_flaubert_lm_head(
......@@ -169,13 +166,7 @@ class FlaubertModelTester(object):
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
......@@ -195,16 +186,9 @@ class FlaubertModelTester(object):
model.to(torch_device)
model.eval()
outputs = model(input_ids)
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
loss, start_logits, end_logits = outputs
result = model(input_ids)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
result = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -225,10 +209,9 @@ class FlaubertModelTester(object):
model.to(torch_device)
model.eval()
outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
result = model(input_ids)
outputs = model(
result_with_labels = model(
input_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
......@@ -237,7 +220,7 @@ class FlaubertModelTester(object):
p_mask=input_mask,
)
outputs = model(
result_with_labels = model(
input_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
......@@ -245,22 +228,13 @@ class FlaubertModelTester(object):
is_impossible=is_impossible_labels,
)
(total_loss,) = outputs
(total_loss,) = result_with_labels.to_tuple()
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
(total_loss,) = outputs
(total_loss,) = result_with_labels.to_tuple()
result = {
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
}
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
)
......@@ -292,13 +266,8 @@ class FlaubertModelTester(object):
model.to(torch_device)
model.eval()
(logits,) = model(input_ids)
loss, logits = model(input_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids)
result = model(input_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
......@@ -320,11 +289,7 @@ class FlaubertModelTester(object):
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -347,16 +312,12 @@ class FlaubertModelTester(object):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -122,9 +122,10 @@ class GPT2ModelTester:
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
# initializer_range=self.initializer_range,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
return_dict=True,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -149,18 +150,14 @@ class GPT2ModelTester:
model.to(torch_device)
model.eval()
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids)
sequence_output, presents = model(input_ids)
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
result = {
"sequence_output": sequence_output,
"presents": presents,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertEqual(len(result["presents"]), config.n_layer)
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = GPT2Model(config=config)
......@@ -175,7 +172,7 @@ class GPT2ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past = outputs
output, past = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
......@@ -185,8 +182,8 @@ class GPT2ModelTester:
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
output_from_no_past, _ = model(next_input_ids, token_type_ids=next_token_type_ids)
output_from_past, _ = model(next_tokens, token_type_ids=next_token_types, past=past)
output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
......@@ -209,7 +206,7 @@ class GPT2ModelTester:
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past = model(input_ids, attention_mask=attn_mask)
output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
......@@ -226,8 +223,8 @@ class GPT2ModelTester:
)
# get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
output_from_past, _ = model(next_tokens, past=past, attention_mask=attn_mask)
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
......@@ -242,13 +239,10 @@ class GPT2ModelTester:
model.to(torch_device)
model.eval()
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
def create_and_check_double_lm_head_model(
......@@ -270,11 +264,8 @@ class GPT2ModelTester:
"labels": multiple_choice_inputs_ids,
}
loss, lm_logits, mc_logits, _ = model(**inputs)
result = {"loss": loss, "lm_logits": lm_logits, "mc_logits": mc_logits}
self.parent.assertListEqual(list(result["loss"].size()), [])
result = model(**inputs)
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
)
......
......@@ -108,6 +108,7 @@ class LongformerModelTester:
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
attention_window=self.attention_window,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -123,8 +124,8 @@ class LongformerModelTester:
model.eval()
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
output_with_mask = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
output_without_mask = model(input_ids)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model(
......@@ -133,18 +134,13 @@ class LongformerModelTester:
model = LongformerModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -156,25 +152,19 @@ class LongformerModelTester:
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
global_attention_mask = global_attention_mask.to(torch_device)
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
result = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -182,16 +172,8 @@ class LongformerModelTester:
model = LongformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_longformer_for_question_answering(
......@@ -200,7 +182,7 @@ class LongformerModelTester:
model = LongformerForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=input_mask,
......@@ -208,11 +190,6 @@ class LongformerModelTester:
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -224,13 +201,7 @@ class LongformerModelTester:
model = LongformerForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -241,11 +212,7 @@ class LongformerModelTester:
model = LongformerForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -260,17 +227,13 @@ class LongformerModelTester:
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
global_attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
return_dict=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
self.assertEqual(result["logits"].shape, expected_shape)
@require_torch
......
......@@ -122,6 +122,7 @@ class MobileBertModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -162,18 +163,14 @@ class MobileBertModelTester:
model = MobileBertModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_mobilebert_model_as_decoder(
self,
......@@ -190,29 +187,25 @@ class MobileBertModelTester:
model = MobileBertModel(config)
model.to(torch_device)
model.eval()
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
sequence_output, pooled_output = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_mobilebert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
......@@ -220,16 +213,8 @@ class MobileBertModelTester:
model = MobileBertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_mobilebert_for_next_sequence_prediction(
......@@ -238,14 +223,10 @@ class MobileBertModelTester:
model = MobileBertForNextSentencePrediction(config=config)
model.to(torch_device)
model.eval()
loss, seq_relationship_score = model(
result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_mobilebert_for_pretraining(
......@@ -254,22 +235,17 @@ class MobileBertModelTester:
model = MobileBertForPreTraining(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores, seq_relationship_score = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
labels=token_labels,
next_sentence_label=sequence_labels,
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
self.check_loss_output(result)
def create_and_check_mobilebert_for_question_answering(
......@@ -278,18 +254,13 @@ class MobileBertModelTester:
model = MobileBertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
......@@ -301,13 +272,7 @@ class MobileBertModelTester:
model = MobileBertForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......@@ -318,11 +283,7 @@ class MobileBertModelTester:
model = MobileBertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
......@@ -336,16 +297,12 @@ class MobileBertModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
......
......@@ -85,9 +85,10 @@ class OpenAIGPTModelTester:
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
return_dict=True,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......@@ -110,13 +111,12 @@ class OpenAIGPTModelTester:
model.to(torch_device)
model.eval()
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
model(input_ids, token_type_ids=token_type_ids)
(sequence_output,) = model(input_ids)
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
result = {"sequence_output": sequence_output}
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
......@@ -124,13 +124,10 @@ class OpenAIGPTModelTester:
model.to(torch_device)
model.eval()
loss, lm_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
......@@ -138,11 +135,8 @@ class OpenAIGPTModelTester:
model.to(torch_device)
model.eval()
loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
result = {"loss": loss, "lm_logits": lm_logits}
self.parent.assertListEqual(list(result["loss"].size()), [])
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
self.parent.assertListEqual(
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
......
......@@ -165,6 +165,7 @@ class ReformerModelTester:
attn_layers=self.attn_layers,
pad_token_id=self.pad_token_id,
hash_seed=self.hash_seed,
return_dict=True,
)
return (
......@@ -181,15 +182,12 @@ class ReformerModelTester:
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
sequence_output, _ = model(input_ids, attention_mask=input_mask)
sequence_output, _ = model(input_ids)
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
result = {
"sequence_output": sequence_output,
}
# 2 * hidden_size because we use reversible resnet layers
self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
)
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
......@@ -198,7 +196,7 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
loss.backward()
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
......@@ -207,13 +205,9 @@ class ReformerModelTester:
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
......@@ -222,13 +216,9 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
......@@ -325,7 +315,7 @@ class ReformerModelTester:
model.to(torch_device)
model.eval()
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
......@@ -408,7 +398,7 @@ class ReformerModelTester:
model.to(torch_device)
model.half()
model.eval()
output = model(input_ids, attention_mask=input_mask)[0]
output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
......@@ -444,21 +434,16 @@ class ReformerModelTester:
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0]
output_logits = model(input_ids, attention_mask=input_mask)["logits"]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
    """Check ReformerForQuestionAnswering: output logit shapes and loss.

    Builds the QA head on top of Reformer, runs a forward pass with
    start/end positions supplied (so a loss is produced), and verifies
    that ``start_logits`` and ``end_logits`` each have shape
    ``(batch_size, seq_length)``.  ``choice_labels`` doubles as both the
    start and end positions here — only the shapes matter for this check.
    """
    model = ReformerForQuestionAnswering(config=config)
    model.to(torch_device)
    model.eval()  # deterministic forward: disable dropout
    # config was built with return_dict=True, so the forward returns a
    # ModelOutput that supports dict-style access by field name.
    result = model(
        input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
    )
    self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
    self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
    # Shared helper asserts the loss is a well-formed scalar.
    self.check_loss_output(result)
......@@ -474,11 +459,11 @@ class ReformerModelTester:
input_ids_second = input_ids[:, -1:]
# return saved cache
_, past_buckets_states = model(input_ids_first, use_cache=True)
past_buckets_states = model(input_ids_first, use_cache=True)["past_buckets_states"]
# calculate last output with and without cache
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
outputs_without_cache = model(input_ids)[0][:, -1]
outputs_with_cache = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)["logits"]
outputs_without_cache = model(input_ids)["logits"][:, -1]
# select random slice idx
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
......@@ -504,11 +489,7 @@ class ReformerModelTester:
model = ReformerForSequenceClassification(config)
model.to(torch_device)
model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment