Unverified Commit dd49404a authored by Hwijeen Ahn's avatar Hwijeen Ahn Committed by GitHub
Browse files

check if eval dataset is dict (#24877)

* check if eval dataset is dict

* formatting
parent 5c5cb4ee
...@@ -692,7 +692,13 @@ def main(): ...@@ -692,7 +692,13 @@ def main():
results = {} results = {}
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="eval") if isinstance(eval_dataset, dict):
metrics = {}
for eval_ds_name, eval_ds in eval_dataset.items():
dataset_metrics = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix=f"eval_{eval_ds_name}")
metrics.update(dataset_metrics)
else:
metrics = trainer.evaluate(metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment