Commit 90cf3b89 authored by Baber

rename

parent c407be5b
@@ -1777,7 +1777,7 @@ class ConfigurableTask(Task):
             f"num_samples={len(self.eval_docs)})"
         )
 
-    def calculate_metrics(
+    def compute_sample_metrics(
         self,
         requests: list[Instance] = None,
         filter_keys: list[str] = None,
@@ -1804,11 +1804,13 @@ class ConfigurableTask(Task):
         """
         if not requests and not self.instances:
            return None, None
+        else:
+            requests = requests if requests else self.instances
 
         ### Collect values of metrics on all datapoints ###
         # Pre-process task.instances to group by doc_id
         instances_by_doc_id = defaultdict(list)
-        for instance in self.instances:
+        for instance in requests:
             instances_by_doc_id[instance.doc_id].append(instance)
         # Sort instances within each group
         for instances in instances_by_doc_id.values():
...
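Context for the changed loop: compute_sample_metrics now iterates over the passed-in requests (falling back to self.instances) and buckets them by doc_id before sorting each bucket. A minimal, self-contained sketch of that grouping step follows; the Instance dataclass, its idx field, and the sample values are illustrative stand-ins rather than the harness's actual types.

from collections import defaultdict
from dataclasses import dataclass
from typing import Any


@dataclass
class Instance:
    # hypothetical stand-in for the harness's Instance; only the fields used here
    doc_id: int
    idx: int
    resp: Any


requests = [
    Instance(doc_id=1, idx=1, resp="b"),
    Instance(doc_id=0, idx=0, resp="a"),
    Instance(doc_id=1, idx=0, resp="c"),
]

# group requests by the document they belong to
instances_by_doc_id = defaultdict(list)
for instance in requests:
    instances_by_doc_id[instance.doc_id].append(instance)

# sort instances within each group by their request index
for instances in instances_by_doc_id.values():
    instances.sort(key=lambda x: x.idx)

print({doc_id: [inst.idx for inst in group] for doc_id, group in instances_by_doc_id.items()})
# {1: [0, 1], 0: [0]}

Grouping by doc_id keeps all requests that belong to one document together, so per-document metrics can be computed from a complete, ordered set of responses.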
@@ -588,11 +588,12 @@ def evaluate(
         ### Collect values of metrics on all datapoints ###
         # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
-        _metrics, samples = task.calculate_metrics(
+        _metrics, samples = task.compute_sample_metrics(
             indices=samples,
             rank=RANK,
             limit=limit,
             world_size=WORLD_SIZE,
+            log_samples=log_samples,
         )
         task_output.sample_metrics = _metrics
         if log_samples:
...
@@ -25,9 +25,9 @@ from lm_eval.utils import (
     get_sample_results_filenames,
     handle_non_serializable,
     hash_string,
-    sanitize_list,
     sanitize_model_name,
     sanitize_task_name,
+    serialize_list,
 )
...
@@ -295,7 +295,7 @@ class EvaluationTracker:
         """
         if self.output_path:
             try:
-                eval_logger.info(f"Saving per-sample results for: {task_name}")
+                eval_logger.debug(f"Saving per-sample results for: {task_name}")
 
                 path = Path(self.output_path if self.output_path else Path.cwd())
                 if path.suffix == ".json":
...
@@ -309,7 +309,7 @@ class EvaluationTracker:
                 file_results_samples = path.joinpath(
                     f"samples_{task_name}_{self.date_id}.jsonl"
                 )
+                with file_results_samples.open("a", encoding="utf-8") as f:
                     for sample in samples:
                         # we first need to sanitize arguments and resps
                         # otherwise we won't be able to load the dataset
...
@@ -320,8 +320,10 @@ class EvaluationTracker:
                             for j, tmp in enumerate(arg):
                                 arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
 
-                    sample["resps"] = sanitize_list(sample["resps"])
-                    sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
+                        sample["resps"] = serialize_list(sample["resps"])
+                        sample["filtered_resps"] = serialize_list(
+                            sample["filtered_resps"]
+                        )
                         sample["arguments"] = arguments
                         sample["target"] = str(sample["target"])
...
@@ -334,7 +336,6 @@ class EvaluationTracker:
                             + "\n"
                         )
 
-                    with open(file_results_samples, "a", encoding="utf-8") as f:
                         f.write(sample_dump)
 
                 if self.api and self.push_samples_to_hub:
...
@@ -374,7 +375,7 @@ class EvaluationTracker:
                 )
             except Exception as e:
-                eval_logger.warning("Could not save sample results")
+                eval_logger.warning(f"Could not save sample results for: {task_name}")
                 eval_logger.info(repr(e))
         else:
             eval_logger.info("Output path not provided, skipping saving sample results")
...
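Taken together, the changes in this file rename sanitize_list to serialize_list and hoist the file handle out of the per-sample loop: the .jsonl file is opened once via Path.open and each sample is appended as one JSON line. A rough sketch of that write pattern, assuming made-up sample dicts and a placeholder filename, and using default=str where the tracker uses handle_non_serializable:

import json
from pathlib import Path

# made-up per-sample records standing in for the tracker's real sample dicts
samples = [
    {"doc_id": 0, "resps": [["-1.2", "False"]], "target": "1"},
    {"doc_id": 1, "resps": [["-0.4", "True"]], "target": "0"},
]

file_results_samples = Path("samples_demo.jsonl")  # illustrative filename

# open the JSONL file once, then append one JSON line per sample
with file_results_samples.open("a", encoding="utf-8") as f:
    for sample in samples:
        # the real tracker also runs resps/filtered_resps through serialize_list and
        # uses handle_non_serializable as the json default; str keeps this sketch simple
        f.write(json.dumps(sample, default=str, ensure_ascii=False) + "\n")

Opening the file once avoids reopening and closing it for every sample while keeping the append-one-line-per-sample JSONL layout intact.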
@@ -118,14 +118,14 @@ def handle_non_serializable(o):
         return str(o)
 
 
-def sanitize_list(sub):
+def serialize_list(sub):
     """
     Takes possible nested list and recursively converts all inner component to strings
     """
     if isinstance(sub, list):
-        return [sanitize_list(item) for item in sub]
+        return [serialize_list(item) for item in sub]
     if isinstance(sub, tuple):
-        return tuple(sanitize_list(item) for item in sub)
+        return tuple(serialize_list(item) for item in sub)
     else:
         return str(sub)
...
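The rename does not change behavior: serialize_list still walks nested lists and tuples, preserves the container shapes, and converts every leaf to a string. A quick check with made-up values:

def serialize_list(sub):
    # same logic as the renamed helper above
    if isinstance(sub, list):
        return [serialize_list(item) for item in sub]
    if isinstance(sub, tuple):
        return tuple(serialize_list(item) for item in sub)
    return str(sub)


print(serialize_list([[1, 2.5], (None, True), "x"]))
# [['1', '2.5'], ('None', 'True'), 'x']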