Unverified Commit 19fa01ce authored by Stefan Schweter's avatar Stefan Schweter Committed by GitHub
Browse files

token-classification: use is_world_process_zero instead of deprecated is_world_master() (#8828)

parent 40ecaf0c
...@@ -369,7 +369,7 @@ def main(): ...@@ -369,7 +369,7 @@ def main():
] ]
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
if trainer.is_world_master(): if trainer.is_world_process_zero():
with open(output_test_results_file, "w") as writer: with open(output_test_results_file, "w") as writer:
for key, value in metrics.items(): for key, value in metrics.items():
logger.info(f" {key} = {value}") logger.info(f" {key} = {value}")
...@@ -377,7 +377,7 @@ def main(): ...@@ -377,7 +377,7 @@ def main():
# Save predictions # Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
if trainer.is_world_master(): if trainer.is_world_process_zero():
with open(output_test_predictions_file, "w") as writer: with open(output_test_predictions_file, "w") as writer:
for prediction in true_predictions: for prediction in true_predictions:
writer.write(" ".join(prediction) + "\n") writer.write(" ".join(prediction) + "\n")
......
...@@ -291,7 +291,7 @@ def main(): ...@@ -291,7 +291,7 @@ def main():
preds_list, _ = align_predictions(predictions, label_ids) preds_list, _ = align_predictions(predictions, label_ids)
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
if trainer.is_world_master(): if trainer.is_world_process_zero():
with open(output_test_results_file, "w") as writer: with open(output_test_results_file, "w") as writer:
for key, value in metrics.items(): for key, value in metrics.items():
logger.info(" %s = %s", key, value) logger.info(" %s = %s", key, value)
...@@ -299,7 +299,7 @@ def main(): ...@@ -299,7 +299,7 @@ def main():
# Save predictions # Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
if trainer.is_world_master(): if trainer.is_world_process_zero():
with open(output_test_predictions_file, "w") as writer: with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
token_classification_task.write_predictions_to_file(writer, f, preds_list) token_classification_task.write_predictions_to_file(writer, f, preds_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment