Commit 9093b1a6 authored by Baber

move metric calculation to task

parent 51ab86ff
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union
# @dataclass
......@@ -64,11 +64,11 @@ class MetricResult:
Outputs for the metric function.
"""
doc_id: str | int | None
scores: list[dict[str, float]] | dict
doc_id: Union[str, int]
filter_key: Optional[str] = None
metric_name: Optional[str] = None
metadata: Optional[dict] = None
scores: Optional[Union[list[dict[str, float]], dict]] = None
def __iter__(self):
if self.scores is None:
......
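For orientation, a minimal construction sketch under the reworked MetricResult fields; the doc id, metric name "acc", filter key "none", and score values are illustrative assumptions, not taken from the commit.

from lm_eval.api.schemas import MetricResult

# build a result for one document scored under one filter;
# with the new defaults, filter_key/metric_name/metadata may be omitted
result = MetricResult(
    doc_id=0,
    filter_key="none",
    metric_name="acc",
    scores=[{"acc": 1.0}, {"acc": 0.0}],  # one score dict per sampled response
)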
......@@ -8,8 +8,6 @@ from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from itertools import groupby
from operator import attrgetter
from typing import (
Any,
Dict,
......@@ -30,7 +28,12 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.metrics import (
bits_per_byte,
mean,
stderr_for_metric,
weighted_perplexity,
)
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -39,10 +42,10 @@ from lm_eval.api.registry import (
get_metric_aggregation,
is_higher_better,
)
from lm_eval.api.schemas import MetricResult
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.filters import build_filter_ensemble
from lm_eval.prompts import get_prompt
from lm_eval.utils import create_sample_log, pass_at_k
ALL_OUTPUT_TYPES = [
......@@ -1774,19 +1777,22 @@ class ConfigurableTask(Task):
def calculate_metrics(
self,
instances_by_doc_id=None,
filter_keys=None,
samples=None,
rank=1,
limit=None,
world_size=1,
) -> list[MetricResult]:
requests: list[Instance] = None,
filter_keys: list[str] = None,
indices: list[int] = None,
rank: int = 1,
limit: int = None,
world_size: int = 1,
log_samples: bool = False,
) -> tuple[
Optional[dict[tuple[str, str], list[list[float]]]], Optional[list[dict]]
]:
"""Calculate metrics for all datapoints in the task.
Args:
requests (list[Instance], optional): Instances to score; defaults to the task's own instances.
filter_keys (list[str], optional): Filter key(s) whose filtered responses should be scored.
indices (list, optional): Original sample indices used to map doc positions back to the dataset.
rank (int): The process rank.
limit (int, optional): Limit on number of examples to evaluate.
world_size (int): Total number of processes.
log_samples (bool): Whether to also build per-sample log entries.
......@@ -1794,8 +1800,20 @@ class ConfigurableTask(Task):
Returns:
tuple: A dict mapping (metric_name, filter_key) to per-document score lists, and a list of per-sample log dicts (None when log_samples is False).
"""
if not self._instances:
return
if not requests and not self.instances:
# nothing to score: no requests were passed and the task has no instances
return None, None
### Collect values of metrics on all datapoints ###
# Pre-process task.instances to group by doc_id
instances_by_doc_id = defaultdict(list)
for instance in self.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
for instances in instances_by_doc_id.values():
instances.sort(key=lambda x: x.idx)
_all_metrics = defaultdict(list)
_samples = [] if log_samples else None
if filter_keys is None:
filter_keys = (
......@@ -1805,23 +1823,18 @@ class ConfigurableTask(Task):
)
if isinstance(filter_keys, str):
filter_keys = [filter_keys]
if not instances_by_doc_id:
instances_by_doc_id = defaultdict(list)
for instance in self.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# all_metrics = collections.defaultdict(list)
all_metrics = []
for filter_key in filter_keys:
doc_iterator = self.doc_iterator(
rank=rank,
limit=limit,
world_size=world_size,
# samples=indices,
samples=indices,
)
for doc_id, doc in doc_iterator:
# doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
_doc_id_true = indices[doc_id] if indices else doc_id
_sample_metric = defaultdict(list)
requests = instances_by_doc_id[_doc_id_true]
if len(requests) > 1:
# if one doc has multiple instances then calculate metric together
metrics = self.process_results(
......@@ -1837,24 +1850,54 @@ class ConfigurableTask(Task):
else [req.filtered_resps[filter_key]]
)
]
all_metrics.append(
MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
)
for metric in metrics:
for k, v in metric.items():
_sample_metric[k].append(v)
if log_samples:
_samples.append(
create_sample_log(
doc=doc,
doc_id=_doc_id_true,
target=self.doc_to_target(doc),
requests=requests,
metric_names=metrics,
filter_key=filter_key,
)
)
for metric_name, _score in _sample_metric.items():
_all_metrics[(metric_name, filter_key)].append(_score)
return all_metrics
return _all_metrics, _samples
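For orientation, a minimal, self-contained sketch of the per-sample structure carried by the first return value above; the metric name "acc", the filter key "none", and the scores are illustrative assumptions, not taken from the commit.

from collections import defaultdict

_all_metrics: dict[tuple[str, str], list[list[float]]] = defaultdict(list)
per_doc_scores = {0: [1.0, 0.0], 1: [0.0, 1.0]}  # two docs, two sampled responses each
for doc_id, repeats in per_doc_scores.items():
    # one inner list per document, holding that document's per-repeat scores
    _all_metrics[("acc", "none")].append(repeats)

assert _all_metrics[("acc", "none")] == [[1.0, 0.0], [0.0, 1.0]]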
@staticmethod
def compute_agg_metrics(self, metric_results: list[MetricResult]):
y_sorted = sorted(metric_results, key=attrgetter("filter_key", "metric_name"))
groups = {
key: list(
map(list, zip(*((d[it.metric_name] for d in it.scores) for it in g)))
)
for key, g in groupby(y_sorted, key=attrgetter("filter_key", "metric_name"))
}
def compute_agg_metrics(
self,
metric_results: dict[tuple[str, str], list[list[float]]],
bootstrap_iters: int = 1000,
):
agg_metrics = defaultdict(list)
for (metric_name, filter_key), scores in metric_results.items():
agg_fn = self.aggregation()[metric_name]
metric_key = f"{metric_name},{filter_key}"
self.repeat_metric = pass_at_k
repeats = [
self.repeat_metric(len(x), x.count(1), k=x.count(1) - 1) for x in scores
]
repeat_agg = np.mean(repeats)
agg_metrics[metric_key] = [agg_fn(items) for items in zip(*scores)]
if isinstance(bootstrap_iters, int):
stderr_fn = stderr_for_metric(
metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100)
if metric_name in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
agg_metrics[f"{metric_name}_stderr,{filter_key}"] = [
(stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
for item in zip(*scores)
][0]
agg_metrics[f"{metric_key}_repeat"] = [repeat_agg]
return groups
return agg_metrics
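Read together with calculate_metrics above, a hedged sketch of how these aggregates are formed and keyed; using mean as the aggregation for a hypothetical "acc" metric under the "none" filter is an assumption for illustration, not part of the commit.

from lm_eval.api.metrics import mean
from lm_eval.utils import pass_at_k
import numpy as np

scores = [[1.0, 0.0], [0.0, 1.0]]  # per-document lists of per-repeat scores
per_repeat_agg = [mean(col) for col in zip(*scores)]  # would land under the key "acc,none"
repeat_agg = float(np.mean([pass_at_k(len(x), x.count(1), k=x.count(1) - 1) for x in scores]))
# compute_agg_metrics would additionally store "acc_stderr,none" (bootstrap stderr of the
# first repeat column) and "acc,none_repeat" (the pass_at_k-based repeat average).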
class MultipleChoiceTask(Task):
......
import itertools
import json
import logging
import random
import time
......@@ -28,8 +27,6 @@ from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
handle_non_serializable,
hash_string,
positional_deprecated,
setup_logging,
simple_parse_args_string,
......@@ -592,135 +589,23 @@ def evaluate(
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
# Pre-process task.instances to group by doc_id
instances_by_doc_id = defaultdict(list)
for instance in task.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
for instances in instances_by_doc_id.values():
instances.sort(key=lambda x: x.idx)
# instances_by_doc_id = defaultdict(list)
# for instance in task.instances:
# instances_by_doc_id[instance.doc_id].append(instance)
# # Sort instances within each group
# for instances in instances_by_doc_id.values():
# instances.sort(key=lambda x: x.idx)
# iterate over different filters used
if hasattr(task, "calculate_metrics"):
metrics = task.calculate_metrics(
instances_by_doc_id=instances_by_doc_id,
samples=samples,
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
)
for filter_key in task.instances[0].filtered_resps.keys():
if hasattr(task, "calculate_metrics"):
# Add sample logging here too - similar to what's done in the else branch
if log_samples:
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
doc_id_true = indices[doc_id] if indices else doc_id
requests = instances_by_doc_id[doc_id]
if requests: # Make sure there are requests for this doc_id
doc_metrics = metrics[filter_key][doc_id_true].metric_keys
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": doc_metrics,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update(
{
metrics[filter_key][doc_id_true].metric_keys[
0
]: metrics[filter_key][doc_id_true]
}
)
task_output.logged_samples.append(example)
# Process all metrics returned from calculate_metrics
for filter_key in metrics:
for m_samples in metrics[filter_key]:
for metric, value in m_samples:
task_output.sample_metrics[(metric, filter_key)].append(
value
)
else:
# Fall back to the original approach for non-ConfigurableTask instances
indices = (
samples.get(task_output.task_name, None)
if samples is not None
else None
)
doc_iterator = task.doc_iterator(
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
samples=indices,
)
for doc_id, doc in doc_iterator:
if indices:
doc_id_true = indices[doc_id]
else:
doc_id_true = doc_id
requests = instances_by_doc_id[doc_id]
metrics: list[dict] = [
task.process_results(doc, response)
for req in requests
for response in req.filtered_resps[filter_key]
]
if log_samples:
target = task.doc_to_target(doc)
example = {
"doc_id": doc_id_true,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": metrics,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update({"metrics": metrics})
task_output.logged_samples.append(example)
for x in metrics:
for metric, value in x.items():
task_output.sample_metrics[(metric, filter_key)].append(
value
)
_metrics, samples = task.calculate_metrics(
indices=samples,
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
)
task_output.agg_metrics = task.compute_agg_metrics(_metrics)
task_output.sample_metrics = _metrics
if log_samples:
task_output.logged_samples = samples
if WORLD_SIZE > 1:
# if multigpu, then gather data across all ranks to rank 0
......@@ -756,8 +641,8 @@ def evaluate(
if RANK == 0:
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for task_output in eval_tasks:
task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
# for task_output in eval_tasks:
# task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
(
results,
samples,
......
......@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
return map(lambda r: r, resps)  # NOTE: now passes every response through unchanged, not only the first
@register_filter("take_first_k")
......
......@@ -17,6 +17,8 @@ import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from lm_eval.api.instance import Instance
SPACING = " " * 47
......@@ -406,9 +408,13 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
v = "%.4f" % v if isinstance(v, float) else v
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
se = " N/A" if se == "N/A" else "%.4f" % se
values.append([k, version, f, n, m, hib, v, "±", se])
try:
se = dic[m + "_stderr" + "," + f]
se = " N/A" if se == "N/A" else "%.4f" % se
values.append([k, version, f, n, m, hib, v, "±", se])
except: # noqa: E722
values.append([k, version, f, n, m, hib, v, "", ""])
else:
values.append([k, version, f, n, m, hib, v, "", ""])
k = ""
......@@ -550,3 +556,44 @@ def weighted_f1_score(items):
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="weighted")
return fscore
def create_sample_log(
doc: dict,
doc_id: int,
target: Any,
requests: list[Instance],
metric_names: list[dict],
filter_key: str,
) -> dict:
return {
"doc_id": doc_id,
"doc": doc,
"target": target,
"arguments": [req.args for req in requests],
"resps": [req.resps for req in requests],
"filtered_resps": [req.filtered_resps[filter_key] for req in requests],
"filter": filter_key,
"metrics": metric_names,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
def pass_at_k(n: int, c: int, k: int) -> float:
"""
:param n: total number of samples
:param c: number of correct samples
:param k: k in pass@$k$
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
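A quick numeric check of the estimator above (illustrative values): for k=1 it reduces to c/n, and it short-circuits to 1.0 once fewer than k samples are incorrect.

assert abs(pass_at_k(n=10, c=3, k=1) - 0.3) < 1e-9  # 1 - (7/8)*(8/9)*(9/10) = 0.3
assert pass_at_k(n=5, c=5, k=1) == 1.0              # n - c < k, so pass@k is certain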