Unverified Commit d951c14a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Model output test (#6155)

* Use return_dict=True in all tests

* Formatting
parent 86caab1e
...@@ -96,6 +96,7 @@ class RobertaModelTester: ...@@ -96,6 +96,7 @@ class RobertaModelTester:
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
return_dict=True,
) )
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -109,18 +110,14 @@ class RobertaModelTester: ...@@ -109,18 +110,14 @@ class RobertaModelTester:
model = RobertaModel(config=config) model = RobertaModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids) result = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_roberta_for_masked_lm( def create_and_check_roberta_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...@@ -128,16 +125,8 @@ class RobertaModelTester: ...@@ -128,16 +125,8 @@ class RobertaModelTester:
model = RobertaForMaskedLM(config=config) model = RobertaForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_roberta_for_token_classification( def create_and_check_roberta_for_token_classification(
...@@ -147,11 +136,7 @@ class RobertaModelTester: ...@@ -147,11 +136,7 @@ class RobertaModelTester:
model = RobertaForTokenClassification(config=config) model = RobertaForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -165,16 +150,12 @@ class RobertaModelTester: ...@@ -165,16 +150,12 @@ class RobertaModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -184,18 +165,13 @@ class RobertaModelTester: ...@@ -184,18 +165,13 @@ class RobertaModelTester:
model = RobertaForQuestionAnswering(config=config) model = RobertaForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model( result = model(
input_ids, input_ids,
attention_mask=input_mask, attention_mask=input_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
) )
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -83,6 +83,7 @@ class T5ModelTester: ...@@ -83,6 +83,7 @@ class T5ModelTester:
bos_token_id=self.pad_token_id, bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id, decoder_start_token_id=self.decoder_start_token_id,
return_dict=True,
) )
return ( return (
...@@ -136,13 +137,17 @@ class T5ModelTester: ...@@ -136,13 +137,17 @@ class T5ModelTester:
model = T5Model(config=config) model = T5Model(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
decoder_output, decoder_past, encoder_output = model( result = model(
input_ids=input_ids, input_ids=input_ids,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
) )
decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
self.parent.assertEqual(len(decoder_past), 2) self.parent.assertEqual(len(decoder_past), 2)
...@@ -162,10 +167,9 @@ class T5ModelTester: ...@@ -162,10 +167,9 @@ class T5ModelTester:
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
labels=lm_labels, labels=lm_labels,
) )
loss, prediction_scores, _, _ = outputs
self.parent.assertEqual(len(outputs), 4) self.parent.assertEqual(len(outputs), 4)
self.parent.assertEqual(prediction_scores.size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(loss.size(), ()) self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_t5_decoder_model_past( def create_and_check_t5_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
...@@ -179,7 +183,7 @@ class T5ModelTester: ...@@ -179,7 +183,7 @@ class T5ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_value_states = outputs output, past_key_value_states = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -187,8 +191,8 @@ class T5ModelTester: ...@@ -187,8 +191,8 @@ class T5ModelTester:
# append to next input_ids and # append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)[0] output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0] output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
# select random slice # select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
...@@ -212,7 +216,7 @@ class T5ModelTester: ...@@ -212,7 +216,7 @@ class T5ModelTester:
attn_mask[:, half_seq_length:] = 0 attn_mask[:, half_seq_length:] = 0
# first forward pass # first forward pass
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True) output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -229,8 +233,10 @@ class T5ModelTester: ...@@ -229,8 +233,10 @@ class T5ModelTester:
) )
# get two different outputs # get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0] output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[0] output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice # select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
...@@ -256,7 +262,7 @@ class T5ModelTester: ...@@ -256,7 +262,7 @@ class T5ModelTester:
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
): ):
model = T5Model(config=config).to(torch_device).half().eval() model = T5Model(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0] output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
......
...@@ -75,6 +75,7 @@ class TransfoXLModelTester: ...@@ -75,6 +75,7 @@ class TransfoXLModelTester:
div_val=self.div_val, div_val=self.div_val,
n_layer=self.num_hidden_layers, n_layer=self.num_hidden_layers,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
return_dict=True,
) )
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)
...@@ -88,13 +89,13 @@ class TransfoXLModelTester: ...@@ -88,13 +89,13 @@ class TransfoXLModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
hidden_states_1, mems_1 = model(input_ids_1) outputs1 = model(input_ids_1)
hidden_states_2, mems_2 = model(input_ids_2, mems_1) outputs2 = model(input_ids_2, outputs1["mems"])
outputs = { outputs = {
"hidden_states_1": hidden_states_1, "hidden_states_1": outputs1["last_hidden_state"],
"mems_1": mems_1, "mems_1": outputs1["mems"],
"hidden_states_2": hidden_states_2, "hidden_states_2": outputs2["last_hidden_state"],
"mems_2": mems_2, "mems_2": outputs2["mems"],
} }
return outputs return outputs
...@@ -119,17 +120,17 @@ class TransfoXLModelTester: ...@@ -119,17 +120,17 @@ class TransfoXLModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
lm_logits_1, mems_1 = model(input_ids_1) lm_logits_1 = model(input_ids_1)["prediction_scores"]
loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels) outputs1 = model(input_ids_1, labels=lm_labels)
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1) lm_logits_2 = model(input_ids_2, mems=outputs1["mems"])["prediction_scores"]
loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1) outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
outputs = { outputs = {
"loss_1": loss_1, "loss_1": outputs1["losses"],
"mems_1": mems_1, "mems_1": outputs1["mems"],
"lm_logits_1": lm_logits_1, "lm_logits_1": lm_logits_1,
"loss_2": loss_2, "loss_2": outputs2["losses"],
"mems_2": mems_2, "mems_2": outputs2["mems"],
"lm_logits_2": lm_logits_2, "lm_logits_2": lm_logits_2,
} }
return outputs return outputs
......
...@@ -113,6 +113,7 @@ class XLMModelTester: ...@@ -113,6 +113,7 @@ class XLMModelTester:
use_proj=self.use_proj, use_proj=self.use_proj,
num_labels=self.num_labels, num_labels=self.num_labels,
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
return_dict=True,
) )
return ( return (
...@@ -145,15 +146,11 @@ class XLMModelTester: ...@@ -145,15 +146,11 @@ class XLMModelTester:
model = XLMModel(config=config) model = XLMModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids) result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids) result = model(input_ids, langs=token_type_ids)
outputs = model(input_ids) result = model(input_ids)
sequence_output = outputs[0]
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
) )
def create_and_check_xlm_lm_head( def create_and_check_xlm_lm_head(
...@@ -172,13 +169,7 @@ class XLMModelTester: ...@@ -172,13 +169,7 @@ class XLMModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels) result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
...@@ -201,13 +192,7 @@ class XLMModelTester: ...@@ -201,13 +192,7 @@ class XLMModelTester:
outputs = model(input_ids) outputs = model(input_ids)
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
loss, start_logits, end_logits = outputs result = outputs
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -228,10 +213,9 @@ class XLMModelTester: ...@@ -228,10 +213,9 @@ class XLMModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) result = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model( result_with_labels = model(
input_ids, input_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -240,7 +224,7 @@ class XLMModelTester: ...@@ -240,7 +224,7 @@ class XLMModelTester:
p_mask=input_mask, p_mask=input_mask,
) )
outputs = model( result_with_labels = model(
input_ids, input_ids,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -248,22 +232,13 @@ class XLMModelTester: ...@@ -248,22 +232,13 @@ class XLMModelTester:
is_impossible=is_impossible_labels, is_impossible=is_impossible_labels,
) )
(total_loss,) = outputs (total_loss,) = result_with_labels.to_tuple()
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels) result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
(total_loss,) = outputs
result = { (total_loss,) = result_with_labels.to_tuple()
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top] list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
) )
...@@ -295,14 +270,8 @@ class XLMModelTester: ...@@ -295,14 +270,8 @@ class XLMModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(logits,) = model(input_ids) result = model(input_ids)
loss, logits = model(input_ids, labels=sequence_labels) result = model(input_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
...@@ -323,11 +292,7 @@ class XLMModelTester: ...@@ -323,11 +292,7 @@ class XLMModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) result = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
...@@ -350,16 +315,12 @@ class XLMModelTester: ...@@ -350,16 +315,12 @@ class XLMModelTester:
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model( result = model(
multiple_choice_inputs_ids, multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask, attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids, token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels, labels=choice_labels,
) )
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
......
...@@ -28,7 +28,7 @@ if is_torch_available(): ...@@ -28,7 +28,7 @@ if is_torch_available():
class XLMRobertaModelIntegrationTest(unittest.TestCase): class XLMRobertaModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_xlm_roberta_base(self): def test_xlm_roberta_base(self):
model = XLMRobertaModel.from_pretrained("xlm-roberta-base") model = XLMRobertaModel.from_pretrained("xlm-roberta-base", return_dict=True)
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
# The dog is cute and lives in the garden house # The dog is cute and lives in the garden house
...@@ -40,14 +40,14 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase): ...@@ -40,14 +40,14 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
# xlmr.eval() # xlmr.eval()
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1] # expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
output = model(input_ids)[0].detach() output = model(input_ids)["last_hidden_state"].detach()
self.assertEqual(output.shape, expected_output_shape) self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim # compare the actual values for a slice of last dim
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
@slow @slow
def test_xlm_roberta_large(self): def test_xlm_roberta_large(self):
model = XLMRobertaModel.from_pretrained("xlm-roberta-large") model = XLMRobertaModel.from_pretrained("xlm-roberta-large", return_dict=True)
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
# The dog is cute and lives in the garden house # The dog is cute and lives in the garden house
...@@ -59,7 +59,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase): ...@@ -59,7 +59,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
# xlmr.eval() # xlmr.eval()
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1] # expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
output = model(input_ids)[0].detach() output = model(input_ids)["last_hidden_state"].detach()
self.assertEqual(output.shape, expected_output_shape) self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim # compare the actual values for a slice of last dim
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
...@@ -137,6 +137,7 @@ class XLNetModelTester: ...@@ -137,6 +137,7 @@ class XLNetModelTester:
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
return_dict=True,
) )
return ( return (
...@@ -177,15 +178,10 @@ class XLNetModelTester: ...@@ -177,15 +178,10 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
_, _ = model(input_ids_1, input_mask=input_mask) result = model(input_ids_1, input_mask=input_mask)
_, _ = model(input_ids_1, attention_mask=input_mask) result = model(input_ids_1, attention_mask=input_mask)
_, _ = model(input_ids_1, token_type_ids=segment_ids) result = model(input_ids_1, token_type_ids=segment_ids)
outputs, mems_1 = model(input_ids_1) result = model(input_ids_1)
result = {
"mems_1": mems_1,
"outputs": outputs,
}
config.mem_len = 0 config.mem_len = 0
model = XLNetModel(config) model = XLNetModel(config)
...@@ -195,10 +191,10 @@ class XLNetModelTester: ...@@ -195,10 +191,10 @@ class XLNetModelTester:
self.parent.assertEqual(len(base_model_output), 2) self.parent.assertEqual(len(base_model_output), 2)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size], list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
) )
...@@ -233,7 +229,7 @@ class XLNetModelTester: ...@@ -233,7 +229,7 @@ class XLNetModelTester:
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf)) self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1) self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
output, mems = outputs_cache output, mems = outputs_cache.to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -253,8 +249,8 @@ class XLNetModelTester: ...@@ -253,8 +249,8 @@ class XLNetModelTester:
single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device) single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device)
# second forward pass # second forward pass
output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask) output_from_no_past = model(next_input_ids, perm_mask=causal_mask)["last_hidden_state"]
output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask) output_from_past = model(next_tokens, mems=mems, perm_mask=single_mask)["last_hidden_state"]
# select random slice # select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
...@@ -283,7 +279,7 @@ class XLNetModelTester: ...@@ -283,7 +279,7 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True) attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)["attentions"]
self.parent.assertEqual(len(attentions), config.n_layer) self.parent.assertEqual(len(attentions), config.n_layer)
self.parent.assertIsInstance(attentions[0], tuple) self.parent.assertIsInstance(attentions[0], tuple)
...@@ -309,36 +305,27 @@ class XLNetModelTester: ...@@ -309,36 +305,27 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels) result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping) result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"])
result = { _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
"loss_1": loss_1,
"mems_1": mems_1,
"all_logits_1": all_logits_1,
"loss_2": loss_2,
"mems_2": mems_2,
"all_logits_2": all_logits_2,
}
self.parent.assertListEqual(list(result["loss_1"].size()), []) self.parent.assertListEqual(list(result1["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result1["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result1["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
) )
self.parent.assertListEqual(list(result["loss_2"].size()), []) self.parent.assertListEqual(list(result2["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size], list(result2["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result2["mems"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
) )
...@@ -361,10 +348,9 @@ class XLNetModelTester: ...@@ -361,10 +348,9 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids_1) result = model(input_ids_1)
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
outputs = model( result_with_labels = model(
input_ids_1, input_ids_1,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -373,7 +359,7 @@ class XLNetModelTester: ...@@ -373,7 +359,7 @@ class XLNetModelTester:
p_mask=input_mask, p_mask=input_mask,
) )
outputs = model( result_with_labels = model(
input_ids_1, input_ids_1,
start_positions=sequence_labels, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -381,23 +367,13 @@ class XLNetModelTester: ...@@ -381,23 +367,13 @@ class XLNetModelTester:
is_impossible=is_impossible_labels, is_impossible=is_impossible_labels,
) )
total_loss, mems = outputs total_loss, mems = result_with_labels.to_tuple()
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,) result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
total_loss, mems = outputs total_loss, mems = result_with_labels.to_tuple()
result = { self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
"mems": mems,
}
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top], list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
) )
...@@ -436,21 +412,15 @@ class XLNetModelTester: ...@@ -436,21 +412,15 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
logits, mems_1 = model(input_ids_1) result = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=token_labels) result = model(input_ids_1, labels=token_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size], list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
) )
...@@ -473,21 +443,15 @@ class XLNetModelTester: ...@@ -473,21 +443,15 @@ class XLNetModelTester:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
logits, mems_1 = model(input_ids_1) result = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels) result = model(input_ids_1, labels=sequence_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size], list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment