Commit 90cf3b89 authored by Baber

rename

parent c407be5b
@@ -1777,7 +1777,7 @@ class ConfigurableTask(Task):
f"num_samples={len(self.eval_docs)})"
)
def calculate_metrics(
def compute_sample_metrics(
self,
requests: list[Instance] = None,
filter_keys: list[str] = None,
@@ -1804,11 +1804,13 @@ class ConfigurableTask(Task):
"""
if not requests and not self.instances:
return None, None
else:
requests = requests if requests else self.instances
### Collect values of metrics on all datapoints ###
# Pre-process task.instances to group by doc_id
instances_by_doc_id = defaultdict(list)
for instance in self.instances:
for instance in requests:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
for instances in instances_by_doc_id.values():
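The renamed compute_sample_metrics now groups whatever requests it is given (falling back to self.instances) by doc_id before scoring. A minimal sketch of that defaultdict grouping pattern, using a simplified stand-in for the harness Instance class; the idx sort key is an assumption about how per-document ordering is restored, since the diff only shows the grouping and the sort comment:

from collections import defaultdict
from dataclasses import dataclass

# Simplified stand-in for the harness Instance; only the fields the
# grouping logic needs are modelled here.
@dataclass
class DemoInstance:
    doc_id: int
    idx: int
    resp: str

requests = [
    DemoInstance(doc_id=1, idx=1, resp="b"),
    DemoInstance(doc_id=0, idx=0, resp="x"),
    DemoInstance(doc_id=1, idx=0, resp="a"),
]

# Group by doc_id, then sort within each group so a document's
# responses come back in a deterministic order before metrics run.
instances_by_doc_id = defaultdict(list)
for instance in requests:
    instances_by_doc_id[instance.doc_id].append(instance)
for instances in instances_by_doc_id.values():
    instances.sort(key=lambda x: x.idx)

print({doc_id: [i.resp for i in group] for doc_id, group in instances_by_doc_id.items()})
# {1: ['a', 'b'], 0: ['x']}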
@@ -588,11 +588,12 @@ def evaluate(
### Collect values of metrics on all datapoints ###
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
_metrics, samples = task.calculate_metrics(
_metrics, samples = task.compute_sample_metrics(
indices=samples,
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
log_samples=log_samples,
)
task_output.sample_metrics = _metrics
if log_samples:
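For orientation, the call site now forwards log_samples and unpacks a (metrics, samples) pair, with _metrics stored on task_output.sample_metrics. A hedged sketch of consuming such a pair; the (metric_name, filter_key) keying and the mean aggregation are assumptions for illustration, not something this diff shows:

from collections import defaultdict
from statistics import mean

# Assumed shape: {(metric_name, filter_key): [per-sample values]}.
_metrics = {("acc", "none"): [1, 0, 1], ("f1", "none"): [0.5, 0.75, 1.0]}
samples = [{"doc_id": i} for i in range(3)]  # only retained when log_samples is set

sample_metrics = defaultdict(list)
for key, values in _metrics.items():
    sample_metrics[key].extend(values)

# Example aggregation: mean of each metric over its per-sample values.
print({f"{metric},{filt}": mean(vals) for (metric, filt), vals in sample_metrics.items()})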
@@ -25,9 +25,9 @@ from lm_eval.utils import (
get_sample_results_filenames,
handle_non_serializable,
hash_string,
sanitize_list,
sanitize_model_name,
sanitize_task_name,
serialize_list,
)
@@ -295,7 +295,7 @@ class EvaluationTracker:
"""
if self.output_path:
try:
eval_logger.info(f"Saving per-sample results for: {task_name}")
eval_logger.debug(f"Saving per-sample results for: {task_name}")
path = Path(self.output_path if self.output_path else Path.cwd())
if path.suffix == ".json":
@@ -309,32 +309,33 @@ class EvaluationTracker:
file_results_samples = path.joinpath(
f"samples_{task_name}_{self.date_id}.jsonl"
)
for sample in samples:
# we first need to sanitize arguments and resps
# otherwise we won't be able to load the dataset
# using the datasets library
arguments = {}
for i, arg in enumerate(sample["arguments"]):
arguments[f"gen_args_{i}"] = {}
for j, tmp in enumerate(arg):
arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
sample["resps"] = sanitize_list(sample["resps"])
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
sample["arguments"] = arguments
sample["target"] = str(sample["target"])
sample_dump = (
json.dumps(
sample,
default=handle_non_serializable,
ensure_ascii=False,
with file_results_samples.open("a", encoding="utf-8") as f:
for sample in samples:
# we first need to sanitize arguments and resps
# otherwise we won't be able to load the dataset
# using the datasets library
arguments = {}
for i, arg in enumerate(sample["arguments"]):
arguments[f"gen_args_{i}"] = {}
for j, tmp in enumerate(arg):
arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
sample["resps"] = serialize_list(sample["resps"])
sample["filtered_resps"] = serialize_list(
sample["filtered_resps"]
)
sample["arguments"] = arguments
sample["target"] = str(sample["target"])
sample_dump = (
json.dumps(
sample,
default=handle_non_serializable,
ensure_ascii=False,
)
+ "\n"
)
+ "\n"
)
with open(file_results_samples, "a", encoding="utf-8") as f:
f.write(sample_dump)
if self.api and self.push_samples_to_hub:
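The rewritten block above opens the samples_<task>_<date>.jsonl file once and appends every serialized sample inside that single with-block, rather than re-opening the file for each sample. A self-contained sketch of that append-once JSONL pattern; the file name and sample contents are illustrative only:

import json
from pathlib import Path

def handle_non_serializable(o):
    # Fallback used by json.dumps for objects it cannot encode natively.
    return str(o)

samples = [{"doc_id": 0, "target": "yes"}, {"doc_id": 1, "target": "no"}]
out_file = Path("samples_demo.jsonl")  # illustrative path

# Open once, write one JSON object per line.
with out_file.open("a", encoding="utf-8") as f:
    for sample in samples:
        f.write(
            json.dumps(sample, default=handle_non_serializable, ensure_ascii=False)
            + "\n"
        )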
@@ -374,7 +375,7 @@ class EvaluationTracker:
)
except Exception as e:
eval_logger.warning("Could not save sample results")
eval_logger.warning(f"Could not save sample results for: {task_name}")
eval_logger.info(repr(e))
else:
eval_logger.info("Output path not provided, skipping saving sample results")
@@ -118,14 +118,14 @@ def handle_non_serializable(o):
return str(o)
def sanitize_list(sub):
def serialize_list(sub):
"""
Takes possible nested list and recursively converts all inner component to strings
"""
if isinstance(sub, list):
return [sanitize_list(item) for item in sub]
return [serialize_list(item) for item in sub]
if isinstance(sub, tuple):
return tuple(sanitize_list(item) for item in sub)
return tuple(serialize_list(item) for item in sub)
else:
return str(sub)
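For reference, the renamed helper keeps its behaviour: it walks nested lists and tuples, preserves their shape, and turns every leaf into a string so the saved JSONL stays loadable with the datasets library. A quick usage sketch with illustrative values:

from lm_eval.utils import serialize_list  # the helper defined above

# Nested lists/tuples keep their structure; every leaf becomes a string.
print(serialize_list([("stop", ["\n"]), 3.14, [None, True]]))
# [('stop', ['\n']), '3.14', ['None', 'True']]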