Unverified Commit 18f4eb57 authored by KonradSzafer's avatar KonradSzafer Committed by GitHub
Browse files

eval tracker args fix (#1777)

parent 59cf408a
......@@ -3,7 +3,6 @@ import json
import logging
import os
import sys
from argparse import Namespace
from functools import partial
from typing import Union
......@@ -261,15 +260,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
evaluation_tracker_args = Namespace(**evaluation_tracker_args)
if (
evaluation_tracker_args.push_results_to_hub
or evaluation_tracker_args.push_samples_to_hub
) and not evaluation_tracker_args.hub_results_org:
"push_results_to_hub" in evaluation_tracker_args
or "push_samples_to_hub" in evaluation_tracker_args
) and "hub_results_org" not in evaluation_tracker_args:
raise ValueError(
"If push_results_to_hub or push_samples_to_hub is set, results_org must be specified."
)
if evaluation_tracker_args.push_samples_to_hub and not args.log_samples:
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)
......@@ -376,7 +374,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
evaluation_tracker.save_results_aggregated(results=results, samples=samples)
evaluation_tracker.save_results_aggregated(
results=results, samples=samples if args.log_samples else None
)
if args.log_samples:
for task_name, config in results["configs"].items():
......
......@@ -131,14 +131,15 @@ class EvaluationTracker:
try:
eval_logger.info("Saving results aggregated")
# calculate cumulative hash for each task
# calculate cumulative hash for each task - only if samples are provided
task_hashes = {}
for task_name, task_samples in samples.items():
sample_hashes = [
s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
for s in task_samples
]
task_hashes[task_name] = hash_string("".join(sample_hashes))
if samples:
for task_name, task_samples in samples.items():
sample_hashes = [
s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
for s in task_samples
]
task_hashes[task_name] = hash_string("".join(sample_hashes))
# update initial results dict
results.update({"task_hashes": task_hashes})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment