Unverified commit 18f4eb57, authored by KonradSzafer and committed by GitHub

eval tracker args fix (#1777)

parent 59cf408a
@@ -3,7 +3,6 @@ import json
 import logging
 import os
 import sys
-from argparse import Namespace
 from functools import partial
 from typing import Union
@@ -261,15 +260,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         eval_logger.info(f"Including path: {args.include_path}")
     task_manager = TaskManager(args.verbosity, include_path=args.include_path)

-    evaluation_tracker_args = Namespace(**evaluation_tracker_args)
     if (
-        evaluation_tracker_args.push_results_to_hub
-        or evaluation_tracker_args.push_samples_to_hub
-    ) and not evaluation_tracker_args.hub_results_org:
+        "push_results_to_hub" in evaluation_tracker_args
+        or "push_samples_to_hub" in evaluation_tracker_args
+    ) and "hub_results_org" not in evaluation_tracker_args:
         raise ValueError(
             "If push_results_to_hub or push_samples_to_hub is set, results_org must be specified."
         )
-    if evaluation_tracker_args.push_samples_to_hub and not args.log_samples:
+    if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
         eval_logger.warning(
             "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
         )
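Note on the hunk above: evaluation_tracker_args is now left as the plain dict built from the CLI arguments instead of being converted to an argparse.Namespace (hence the dropped import), so option presence is tested with membership checks rather than attribute access. A minimal, self-contained sketch of the difference; the dict contents here are invented for illustration:

from argparse import Namespace

# Hypothetical tracker args: keys are present only when the
# corresponding CLI flags were actually passed.
evaluation_tracker_args = {"push_results_to_hub": True}

# Old approach: a Namespace raises AttributeError for missing keys.
ns = Namespace(**evaluation_tracker_args)
try:
    ns.hub_results_org
except AttributeError:
    print("Namespace has no attribute 'hub_results_org'")

# New approach: membership tests are safe when a key is absent.
if (
    "push_results_to_hub" in evaluation_tracker_args
    or "push_samples_to_hub" in evaluation_tracker_args
) and "hub_results_org" not in evaluation_tracker_args:
    print("would raise ValueError: results_org must be specified")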
@@ -376,7 +374,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             except Exception as e:
                 eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

-        evaluation_tracker.save_results_aggregated(results=results, samples=samples)
+        evaluation_tracker.save_results_aggregated(
+            results=results, samples=samples if args.log_samples else None
+        )

         if args.log_samples:
             for task_name, config in results["configs"].items():
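The call site now forwards samples only when --log_samples was requested, and passes None otherwise so the tracker can skip per-sample work. A small sketch of the conditional-argument pattern, with a stand-in for the real tracker method:

log_samples = False  # mirrors args.log_samples

def save_results_aggregated(results, samples=None):
    # Stand-in for EvaluationTracker.save_results_aggregated:
    # per-sample processing runs only when samples is not None.
    print("samples provided:", samples is not None)

results = {"results": {}}
samples = {"some_task": []}
save_results_aggregated(results=results, samples=samples if log_samples else None)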
@@ -131,14 +131,15 @@ class EvaluationTracker:
         try:
             eval_logger.info("Saving results aggregated")

-            # calculate cumulative hash for each task
+            # calculate cumulative hash for each task - only if samples are provided
             task_hashes = {}
-            for task_name, task_samples in samples.items():
-                sample_hashes = [
-                    s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
-                    for s in task_samples
-                ]
-                task_hashes[task_name] = hash_string("".join(sample_hashes))
+            if samples:
+                for task_name, task_samples in samples.items():
+                    sample_hashes = [
+                        s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
+                        for s in task_samples
+                    ]
+                    task_hashes[task_name] = hash_string("".join(sample_hashes))

             # update initial results dict
             results.update({"task_hashes": task_hashes})
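The guarded hashing above can be exercised in isolation. A minimal sketch, assuming hash_string is the SHA-256 helper used elsewhere in the harness; the sample dicts are invented:

import hashlib

def hash_string(string: str) -> str:
    # Assumed to match the harness helper: hex digest of the UTF-8 bytes.
    return hashlib.sha256(string.encode("utf-8")).hexdigest()

def compute_task_hashes(samples):
    # Mirrors the patched block: hash only if samples are provided.
    task_hashes = {}
    if samples:
        for task_name, task_samples in samples.items():
            sample_hashes = [
                s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
                for s in task_samples
            ]
            task_hashes[task_name] = hash_string("".join(sample_hashes))
    return task_hashes

print(compute_task_hashes(None))  # {} -- nothing to hash
print(compute_task_hashes(
    {"demo": [{"doc_hash": "a", "prompt_hash": "b", "target_hash": "c"}]}
))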