Unverified Commit 9681f052, authored by Jiahao Li, committed by GitHub

Fix result saving errors of pytorch examples (#20276)

parent e627e9b5
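The recurring bug fixed here: each `_no_trainer` example dumped a hardcoded subset of metric keys into `all_results.json` (e.g. only `eval_accuracy`), which silently drops the remaining metrics and can raise a `KeyError` for tasks whose metric dict uses different keys. The hunks below switch to prefixing every returned key with `eval_` and dumping the whole dict. A minimal sketch of that pattern, with hypothetical metric values:

    import json
    import os

    # Hypothetical metric dict, standing in for the output of metric.compute();
    # the keys vary by task ("accuracy", "f1", "pearson", "rouge1", ...).
    eval_metric = {"accuracy": 0.91, "f1": 0.88}

    output_dir = "outputs"  # stands in for args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # The pattern this commit applies everywhere: prefix every key with "eval_"
    # and dump the whole dict, instead of hardcoding one key such as
    # eval_metric["accuracy"], which fails when that key is absent.
    all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
    with open(os.path.join(output_dir, "all_results.json"), "w") as f:
        json.dump(all_results, f)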
@@ -571,9 +571,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
-    if args.output_dir is not None:
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
+        all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
@@ -666,8 +666,8 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"perplexity": perplexity}, f)
+            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+                json.dump({"perplexity": perplexity}, f)
 
 
 if __name__ == "__main__":
...
@@ -711,8 +711,8 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"perplexity": perplexity}, f)
+            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+                json.dump({"perplexity": perplexity}, f)
 
 
 if __name__ == "__main__":
...
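The two language-modeling hunks above change only indentation; the enclosing lines are not shown, so the reading here is an inference. Judging from the deeper nesting, the intent appears to be moving the dump under the main-process guard, so that only one process writes the file in distributed runs. A sketch of that guard pattern using Accelerate (the `_no_trainer` scripts are Accelerate-based; values here are hypothetical):

    import json
    import os

    from accelerate import Accelerator

    accelerator = Accelerator()
    output_dir = "outputs"  # stands in for args.output_dir
    perplexity = 12.3       # hypothetical evaluation result

    # Only the main process writes the file; other ranks skip the dump,
    # avoiding concurrent writes to the same path.
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "all_results.json"), "w") as f:
            json.dump({"perplexity": perplexity}, f)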
@@ -85,7 +85,7 @@ def parse_args():
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
     parser.add_argument(
-        "--max_length",
+        "--max_seq_length",
         type=int,
         default=128,
         help=(
@@ -424,7 +424,7 @@ def main():
         tokenized_examples = tokenizer(
             first_sentences,
             second_sentences,
-            max_length=args.max_length,
+            max_length=args.max_seq_length,
             padding=padding,
             truncation=True,
         )
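The two hunks above rename the multiple-choice script's option from `--max_length` to `--max_seq_length` and update its single use site, presumably to align the flag name with the other example scripts. A self-contained argparse sketch of the renamed option (the `parse_args` arguments here are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_seq_length",  # renamed from --max_length
        type=int,
        default=128,
        help="Maximum total input sequence length after tokenization.",
    )
    args = parser.parse_args(["--max_seq_length", "256"])

    # The single call site then reads the renamed attribute, roughly:
    #   tokenized_examples = tokenizer(..., max_length=args.max_seq_length,
    #                                  padding=padding, truncation=True)
    print(args.max_seq_length)  # -> 256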
@@ -654,8 +654,10 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
+
+        all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
@@ -681,8 +681,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+        all_results = {f"eval_{k}": v for k, v in eval_metrics.items()}
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"eval_overall_accuracy": eval_metrics["overall_accuracy"]}, f)
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
@@ -747,16 +747,10 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump(
-                {
-                    "eval_rouge1": result["rouge1"],
-                    "eval_rouge2": result["rouge2"],
-                    "eval_rougeL": result["rougeL"],
-                    "eval_rougeLsum": result["rougeLsum"],
-                },
-                f,
-            )
+
+        all_results = {f"eval_{k}": v for k, v in result.items()}
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
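In the summarization hunk above, the hand-written dict of four ROUGE keys becomes the same comprehension, so whatever keys the metric returns are all kept. A quick illustration with hypothetical scores:

    # Hypothetical ROUGE scores; the comprehension keeps every returned key,
    # where the old code listed rouge1/rouge2/rougeL/rougeLsum by hand.
    result = {"rouge1": 42.1, "rouge2": 19.8, "rougeL": 39.5, "rougeLsum": 39.7}

    all_results = {f"eval_{k}": v for k, v in result.items()}
    assert all_results["eval_rougeLsum"] == 39.7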
@@ -625,8 +625,9 @@ def main():
         logger.info(f"mnli-mm: {eval_metric}")
 
     if args.output_dir is not None:
+        all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
@@ -766,10 +766,11 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
-        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-            json.dump(
-                {"eval_accuracy": eval_metric["accuracy"], "train_loss": total_loss.item() / len(train_dataloader)}, f
-            )
+        all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
+        if args.with_tracking:
+            all_results.update({"train_loss": total_loss.item() / len(train_dataloader)})
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(all_results, f)
 
 
 if __name__ == "__main__":
...
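The final hunk additionally guards the `train_loss` entry behind `args.with_tracking`; presumably the running loss is only accumulated when tracking is enabled, so reading it unconditionally would fail. A sketch of the guarded merge (names and values hypothetical):

    import json
    import os

    with_tracking = True  # stands in for args.with_tracking
    eval_metric = {"precision": 0.90, "recall": 0.88, "f1": 0.89}  # hypothetical

    all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
    if with_tracking:
        # Hypothetical stand-ins for the values the script accumulates
        # only when tracking is enabled.
        total_loss, num_batches = 123.4, 1000
        all_results.update({"train_loss": total_loss / num_batches})

    os.makedirs("outputs", exist_ok=True)  # hypothetical output dir
    with open(os.path.join("outputs", "all_results.json"), "w") as f:
        json.dump(all_results, f)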