Commit e691fc09 authored by thomwolf

update QA models tests + run_generation

parent 15d8b126
@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_name', type=str, default=None, required=True,
-                        help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--padding_text", type=str, default="")
     parser.add_argument("--length", type=int, default=20)
@@ -150,15 +152,10 @@ def main():
     set_seed(args)
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    tokenizer = tokenizer_class.from_pretrained(args.model_name)
-    model = model_class.from_pretrained(args.model_name)
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
     model.eval()
......
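The two hunks above change run_generation's CLI from a single --model_name flag, whose substring used to determine the model type, to an explicit --model_type plus --model_name_or_path pair. A minimal, self-contained sketch of the new selection flow, assuming a MODEL_CLASSES mapping like the one the script defines elsewhere (its exact contents are not part of this diff) and class names from the pytorch_transformers package of that era:

    # Minimal sketch of the selection flow after this change, not the full script.
    # MODEL_CLASSES below is an assumed example; the real mapping lives elsewhere
    # in run_generation.py and is not shown in this diff.
    import argparse

    from pytorch_transformers import (GPT2LMHeadModel, GPT2Tokenizer,
                                      XLNetLMHeadModel, XLNetTokenizer)

    MODEL_CLASSES = {
        'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),      # assumed entry
        'xlnet': (XLNetLMHeadModel, XLNetTokenizer),   # assumed entry
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", default=None, type=str, required=True)
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True)
    args = parser.parse_args(["--model_type=gpt2", "--model_name_or_path=gpt2"])

    # The class is now chosen explicitly via --model_type instead of being guessed
    # from a substring of --model_name, as the removed loop used to do.
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)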
@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
                     "--prompt=Hello",
                     "--length=10",
                     "--seed=42"]
-        model_name = "--model_name=openai-gpt"
+        model_type, model_name = ("--model_type=openai-gpt",
+                                  "--model_name_or_path=openai-gpt")
         with patch.object(sys, 'argv', testargs + [model_name]):
             result = run_generation.main()
             self.assertGreaterEqual(len(result), 10)
......
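For reference, a standalone illustration of the sys.argv patching pattern this test relies on; fake_main is a hypothetical stand-in for run_generation.main, used here only so the snippet runs without loading any model:

    # Illustration of driving an argparse-based entry point through a patched argv.
    import sys
    import argparse
    from unittest.mock import patch

    def fake_main():
        # Hypothetical stand-in for run_generation.main(): parse the patched argv only.
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_type", type=str, required=True)
        parser.add_argument("--model_name_or_path", type=str, required=True)
        parser.add_argument("--prompt", type=str, default="")
        parser.add_argument("--length", type=int, default=20)
        parser.add_argument("--seed", type=int, default=42)
        return parser.parse_args()  # reads sys.argv[1:], i.e. the patched list

    testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
    flags = ["--model_type=openai-gpt", "--model_name_or_path=openai-gpt"]
    with patch.object(sys, "argv", testargs + flags):
        args = fake_main()
    assert args.model_type == "openai-gpt" and args.length == 10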
@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
                             cls_index=sequence_labels,
                             is_impossible=is_impossible_labels)
-            total_loss, start_logits, end_logits, cls_logits = outputs
+            (total_loss,) = outputs

             outputs = model(input_ids, start_positions=sequence_labels,
                             end_positions=sequence_labels)
-            total_loss, start_logits, end_logits = outputs
+            (total_loss,) = outputs

             result = {
                 "loss": total_loss,
-                "start_logits": start_logits,
-                "end_logits": end_logits,
+                "start_top_log_probs": start_top_log_probs,
+                "start_top_index": start_top_index,
+                "end_top_log_probs": end_top_log_probs,
+                "end_top_index": end_top_index,
                 "cls_logits": cls_logits,
             }
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
                 list(result["loss"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["start_logits"].size()),
-                [self.batch_size, self.seq_length])
+                list(result["start_top_log_probs"].size()),
+                [self.batch_size, model.config.start_n_top])
             self.parent.assertListEqual(
-                list(result["end_logits"].size()),
-                [self.batch_size, self.seq_length])
+                list(result["start_top_index"].size()),
+                [self.batch_size, model.config.start_n_top])
+            self.parent.assertListEqual(
+                list(result["end_top_log_probs"].size()),
+                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
+            self.parent.assertListEqual(
+                list(result["end_top_index"].size()),
+                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
             self.parent.assertListEqual(
                 list(result["cls_logits"].size()),
                 [self.batch_size])
......
@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                             cls_index=sequence_labels,
                             is_impossible=is_impossible_labels)
-            total_loss, start_logits, end_logits, cls_logits, mems = outputs
+            total_loss, mems = outputs

             outputs = model(input_ids_1, start_positions=sequence_labels,
                             end_positions=sequence_labels)
-            total_loss, start_logits, end_logits, mems = outputs
+            total_loss, mems = outputs

             result = {
                 "loss": total_loss,
-                "start_logits": start_logits,
-                "end_logits": end_logits,
+                "start_top_log_probs": start_top_log_probs,
+                "start_top_index": start_top_index,
+                "end_top_log_probs": end_top_log_probs,
+                "end_top_index": end_top_index,
                 "cls_logits": cls_logits,
                 "mems": mems,
             }
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 list(result["loss"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["start_logits"].size()),
-                [self.batch_size, self.seq_length])
+                list(result["start_top_log_probs"].size()),
+                [self.batch_size, model.config.start_n_top])
             self.parent.assertListEqual(
-                list(result["end_logits"].size()),
-                [self.batch_size, self.seq_length])
+                list(result["start_top_index"].size()),
+                [self.batch_size, model.config.start_n_top])
+            self.parent.assertListEqual(
+                list(result["end_top_log_probs"].size()),
+                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
+            self.parent.assertListEqual(
+                list(result["end_top_index"].size()),
+                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
             self.parent.assertListEqual(
                 list(result["cls_logits"].size()),
                 [self.batch_size])
......
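The reworked XLM and XLNet QA tests above check beam-search style outputs (log-probabilities and indices of the top start positions and, for each start, the top end positions) rather than per-token start/end logits, so the expected sizes depend on start_n_top and end_n_top instead of seq_length. A small worked example with hypothetical values, chosen only to make the arithmetic concrete (not the values used in the test config):

    # Hypothetical sizes illustrating the shape checks in the updated tests.
    batch_size, start_n_top, end_n_top = 2, 5, 5

    start_top_log_probs_shape = [batch_size, start_n_top]               # [2, 5]
    start_top_index_shape     = [batch_size, start_n_top]               # [2, 5]
    end_top_log_probs_shape   = [batch_size, start_n_top * end_n_top]   # [2, 25]
    end_top_index_shape       = [batch_size, start_n_top * end_n_top]   # [2, 25]
    cls_logits_shape          = [batch_size]                            # [2]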