Commit 5adb39e7 authored by Marianne Stecklina's avatar Marianne Stecklina Committed by thomwolf
Browse files

Add option to predict on test set

parent 99b189df
...@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): ...@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics # Log metrics
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer, labels, pad_token_label_id) results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar("eval_{}".format(key), value, global_step) tb_writer.add_scalar("eval_{}".format(key), value, global_step)
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
...@@ -178,8 +178,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): ...@@ -178,8 +178,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
return global_step, tr_loss / global_step return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""): def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=True) eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly # Note that DistributedSampler samples randomly
...@@ -241,15 +241,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""): ...@@ -241,15 +241,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=""):
for key in sorted(results.keys()): for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key])) logger.info(" %s = %s", key, str(results[key]))
return results return results, preds_list
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False): def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
if args.local_rank not in [-1, 0] and not evaluate: if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Load data features from cache or dataset file # Load data features from cache or dataset file
cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format("dev" if evaluate else "train", cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode,
list(filter(None, args.model_name_or_path.split("/"))).pop(), list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length))) str(args.max_seq_length)))
if os.path.exists(cached_features_file): if os.path.exists(cached_features_file):
...@@ -257,7 +257,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluat ...@@ -257,7 +257,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluat
features = torch.load(cached_features_file) features = torch.load(cached_features_file)
else: else:
logger.info("Creating features from dataset file at %s", args.data_dir) logger.info("Creating features from dataset file at %s", args.data_dir)
examples = read_examples_from_file(args.data_dir, evaluate=evaluate) examples = read_examples_from_file(args.data_dir, mode)
features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer, features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer,
cls_token_at_end=bool(args.model_type in ["xlnet"]), cls_token_at_end=bool(args.model_type in ["xlnet"]),
# xlnet has a cls token at the end # xlnet has a cls token at the end
...@@ -318,6 +318,8 @@ def main(): ...@@ -318,6 +318,8 @@ def main():
help="Whether to run training.") help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", parser.add_argument("--do_eval", action="store_true",
help="Whether to run eval on the dev set.") help="Whether to run eval on the dev set.")
parser.add_argument("--do_predict", action="store_true",
help="Whether to run predictions on the test set.")
parser.add_argument("--evaluate_during_training", action="store_true", parser.add_argument("--evaluate_during_training", action="store_true",
help="Whether to run evaluation during training at each logging step.") help="Whether to run evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action="store_true", parser.add_argument("--do_lower_case", action="store_true",
...@@ -433,7 +435,7 @@ def main(): ...@@ -433,7 +435,7 @@ def main():
# Training # Training
if args.do_train: if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, evaluate=False) train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id) global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
...@@ -466,7 +468,7 @@ def main(): ...@@ -466,7 +468,7 @@ def main():
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, labels, pad_token_label_id, prefix=global_step) result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
if global_step: if global_step:
result = {"{}_{}".format(global_step, k): v for k, v in result.items()} result = {"{}_{}".format(global_step, k): v for k, v in result.items()}
results.update(result) results.update(result)
...@@ -475,6 +477,32 @@ def main(): ...@@ -475,6 +477,32 @@ def main():
for key in sorted(results.keys()): for key in sorted(results.keys()):
writer.write("{} = {}\n".format(key, str(results[key]))) writer.write("{} = {}\n".format(key, str(results[key])))
if args.do_predict and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.output_dir)
model.to(args.device)
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
# Save results
output_test_results_file = os.path.join(args.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(result.keys()):
writer.write("{} = {}\n".format(key, str(result[key])))
# Save predictions
output_test_predictions_file = os.path.join(args.output_dir, "test_predictions.txt")
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not predictions[example_id]:
example_id += 1
elif predictions[example_id]:
output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
return results return results
......
...@@ -51,13 +51,8 @@ class InputFeatures(object): ...@@ -51,13 +51,8 @@ class InputFeatures(object):
self.label_ids = label_ids self.label_ids = label_ids
def read_examples_from_file(data_dir, evaluate=False): def read_examples_from_file(data_dir, mode):
if evaluate: file_path = os.path.join(data_dir, "{}.txt".format(mode))
file_path = os.path.join(data_dir, "dev.txt")
guid_prefix = "dev"
else:
file_path = os.path.join(data_dir, "train.txt")
guid_prefix = "train"
guid_index = 1 guid_index = 1
examples = [] examples = []
with open(file_path, encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
...@@ -66,7 +61,7 @@ def read_examples_from_file(data_dir, evaluate=False): ...@@ -66,7 +61,7 @@ def read_examples_from_file(data_dir, evaluate=False):
for line in f: for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n": if line.startswith("-DOCSTART-") or line == "" or line == "\n":
if words: if words:
examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index), examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
words=words, words=words,
labels=labels)) labels=labels))
guid_index += 1 guid_index += 1
...@@ -75,9 +70,13 @@ def read_examples_from_file(data_dir, evaluate=False): ...@@ -75,9 +70,13 @@ def read_examples_from_file(data_dir, evaluate=False):
else: else:
splits = line.split(" ") splits = line.split(" ")
words.append(splits[0]) words.append(splits[0])
if len(splits) > 1:
labels.append(splits[-1].replace("\n", "")) labels.append(splits[-1].replace("\n", ""))
else:
# Examples could have no label for mode = "test"
labels.append("O")
if words: if words:
examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index), examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
words=words, words=words,
labels=labels)) labels=labels))
return examples return examples
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment