Commit 2195c0d5 authored by Brian Ma's avatar Brian Ma
Browse files

Evaluation result.txt path changing #1286

parent ebb32261
...@@ -248,7 +248,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -248,7 +248,7 @@ def evaluate(args, model, tokenizer, prefix=""):
result = compute_metrics(eval_task, preds, out_label_ids) result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result) results.update(result)
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix)) logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()): for key in sorted(result.keys()):
...@@ -489,9 +489,11 @@ def main(): ...@@ -489,9 +489,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
......
...@@ -282,7 +282,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -282,7 +282,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"perplexity": perplexity "perplexity": perplexity
} }
output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix)) logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()): for key in sorted(result.keys()):
...@@ -484,9 +484,11 @@ def main(): ...@@ -484,9 +484,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
......
...@@ -512,9 +512,11 @@ def main(): ...@@ -512,9 +512,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
...@@ -528,9 +530,11 @@ def main(): ...@@ -528,9 +530,11 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint) model = model_class.from_pretrained(checkpoint)
model.to(args.device) model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step, test=True) result = evaluate(args, model, tokenizer, prefix=prefix, test=True)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
if best_steps: if best_steps:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment