Commit 9093b1a6 authored by Baber

move metric calculation to task

parent 51ab86ff
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 # @dataclass
@@ -64,11 +64,11 @@ class MetricResult:
         Outputs for the metric function.
     """
-    doc_id: str | int | None
-    scores: list[dict[str, float]] | dict
+    doc_id: Union[str, int]
     filter_key: str = None
     metric_name: str = None
     metadata: Optional[dict] = None
+    scores: Union[list[dict[str, float]], dict] = None
     def __iter__(self):
         if self.scores is None:
...
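A note on the `MetricResult` reordering: in a Python dataclass, fields without defaults cannot follow fields that have them, so once `scores` moves below the defaulted fields it needs a default of its own. A minimal sketch (illustrative only, not the repo's class; the name and values here are assumptions):

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class MetricResultSketch:
    doc_id: Union[str, int]          # the only required field, so it stays first
    filter_key: str = None           # defaulted fields follow
    metric_name: str = None
    metadata: Optional[dict] = None
    scores: Union[list[dict[str, float]], dict] = None  # needs "= None" once placed here

# doc_id is positional; everything else is optional
result = MetricResultSketch(doc_id=3, scores=[{"acc": 1.0}], filter_key="none")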
@@ -8,8 +8,6 @@ from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from itertools import groupby
-from operator import attrgetter
 from typing import (
     Any,
     Dict,
@@ -30,7 +28,12 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api import samplers
 from lm_eval.api.instance import Instance, OutputType
-from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
+from lm_eval.api.metrics import (
+    bits_per_byte,
+    mean,
+    stderr_for_metric,
+    weighted_perplexity,
+)
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
@@ -39,10 +42,10 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt
+from lm_eval.utils import create_sample_log, pass_at_k
 ALL_OUTPUT_TYPES = [
@@ -1774,19 +1777,22 @@ class ConfigurableTask(Task):
     def calculate_metrics(
         self,
-        instances_by_doc_id=None,
-        filter_keys=None,
-        samples=None,
-        rank=1,
-        limit=None,
-        world_size=1,
-    ) -> list[MetricResult]:
+        requests: list[Instance] = None,
+        filter_keys: list[str] = None,
+        indices: list[int] = None,
+        rank: int = 1,
+        limit: int = None,
+        world_size: int = 1,
+        log_samples: bool = False,
+    ) -> tuple[
+        Optional[dict[tuple[str, str], list[list[float]]]], Optional[list[dict]]
+    ]:
         """Calculate metrics for all datapoints in the task.
         Args:
             instances_by_doc_id (dict): Dictionary mapping doc_ids to lists of instances.
             filter_key (str): The filter key to use for filtered responses.
-            samples (dict, optional): Dictionary of sample indices to evaluate.
+            indices (dict, optional): Dictionary of sample indices to evaluate.
             rank (int): The process rank.
             limit (int, optional): Limit on number of examples to evaluate.
             world_size (int): Total number of processes.
@@ -1794,8 +1800,20 @@ class ConfigurableTask(Task):
         Returns:
             list: A list of metrics calculated for each document.
         """
-        if not self._instances:
-            return
+        if not requests and not self.instances:
+            print("sent results")
+            return None, None
+        ### Collect values of metrics on all datapoints ###
+        # Pre-process task.instances to group by doc_id
+        instances_by_doc_id = defaultdict(list)
+        for instance in self.instances:
+            instances_by_doc_id[instance.doc_id].append(instance)
+        # Sort instances within each group
+        for instances in instances_by_doc_id.values():
+            instances.sort(key=lambda x: x.idx)
+        _all_metrics = defaultdict(list)
+        _samples = [] if log_samples else None
         if filter_keys is None:
             filter_keys = (
@@ -1805,23 +1823,18 @@ class ConfigurableTask(Task):
             )
         if isinstance(filter_keys, str):
             filter_keys = [filter_keys]
-        if not instances_by_doc_id:
-            instances_by_doc_id = defaultdict(list)
-            for instance in self.instances:
-                instances_by_doc_id[instance.doc_id].append(instance)
-        # all_metrics = collections.defaultdict(list)
-        all_metrics = []
         for filter_key in filter_keys:
             doc_iterator = self.doc_iterator(
                 rank=rank,
                 limit=limit,
                 world_size=world_size,
-                # samples=indices,
+                samples=indices,
             )
             for doc_id, doc in doc_iterator:
-                # doc_id_true = indices[doc_id] if indices else doc_id
-                requests = instances_by_doc_id[doc_id]
+                _doc_id_true = indices[doc_id] if indices else doc_id
+                _sample_metric = defaultdict(list)
+                requests = instances_by_doc_id[_doc_id_true]
                 if len(requests) > 1:
                     # if one doc has multiple instances then calculate metric together
                     metrics = self.process_results(
@@ -1837,24 +1850,54 @@ class ConfigurableTask(Task):
                            else [req.filtered_resps[filter_key]]
                        )
                    ]
-                all_metrics.append(
-                    MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
-                )
+                for metric in metrics:
+                    for k, v in metric.items():
+                        _sample_metric[k].append(v)
+                if log_samples:
+                    _samples.append(
+                        create_sample_log(
+                            doc=doc,
+                            doc_id=_doc_id_true,
+                            target=self.doc_to_target(doc),
+                            requests=requests,
+                            metric_names=metrics,
+                            filter_key=filter_key,
+                        )
+                    )
+                for metric_name, _score in _sample_metric.items():
+                    _all_metrics[(metric_name, filter_key)].append(_score)
-        return all_metrics
+        return _all_metrics, _samples
-    @staticmethod
-    def compute_agg_metrics(self, metric_results: list[MetricResult]):
-        y_sorted = sorted(metric_results, key=attrgetter("filter_key", "metric_name"))
-        groups = {
-            key: list(
-                map(list, zip(*((d[it.metric_name] for d in it.scores) for it in g)))
-            )
-            for key, g in groupby(y_sorted, key=attrgetter("filter_key", "metric_name"))
-        }
-        return groups
+    def compute_agg_metrics(
+        self,
+        metric_results: dict[tuple[str, str], list[list[float]]],
+        bootstrap_iters: int = 1000,
+    ):
+        agg_metrics = defaultdict(list)
+        for (metric_name, filter_key), scores in metric_results.items():
+            agg_fn = self.aggregation()[metric_name]
+            metric_key = f"{metric_name},{filter_key}"
+            self.repeat_metric = pass_at_k
+            repeats = [
+                self.repeat_metric(len(x), x.count(1), k=x.count(1) - 1) for x in scores
+            ]
+            repeat_agg = np.mean(repeats)
+            agg_metrics[metric_key] = [agg_fn(items) for items in zip(*scores)]
+            if isinstance(bootstrap_iters, int):
+                stderr_fn = stderr_for_metric(
+                    metric=agg_fn,
+                    bootstrap_iters=min(bootstrap_iters, 100)
+                    if metric_name in ["bleu", "chrf", "ter"]
+                    else bootstrap_iters,
+                )
+                agg_metrics[f"{metric_name}_stderr,{filter_key}"] = [
+                    (stderr_fn(item) if (stderr_fn and len(item) > 1) else "N/A")
+                    for item in zip(*scores)
+                ][0]
+            agg_metrics[f"{metric_key}_repeat"] = [repeat_agg]
-        return groups
+        return agg_metrics
 class MultipleChoiceTask(Task):
...
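To make the new data flow concrete, here is a small illustrative sketch (metric names and scores are made up) of the structure `calculate_metrics` now returns and `compute_agg_metrics` consumes: keys are `(metric_name, filter_key)` tuples, and each value holds one inner list per document with one score per repeated sample.

from collections import defaultdict

_all_metrics = defaultdict(list)
# two documents, each answered three times under the "none" filter
_all_metrics[("acc", "none")].append([1.0, 0.0, 1.0])  # doc 0
_all_metrics[("acc", "none")].append([1.0, 1.0, 1.0])  # doc 1

# compute_agg_metrics aggregates column-wise: zip(*scores) groups the i-th
# repeat of every doc together, giving one aggregate per repeat index, plus
# an "acc_stderr,none" bootstrap stderr and an "acc,none_repeat" pass@k-style
# score over the repeats.
scores = _all_metrics[("acc", "none")]
per_repeat_acc = [sum(col) / len(col) for col in zip(*scores)]  # [1.0, 0.5, 1.0]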
 import itertools
-import json
 import logging
 import random
 import time
@@ -28,8 +27,6 @@ from lm_eval.loggers import EvaluationTracker
 from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
-    handle_non_serializable,
-    hash_string,
     positional_deprecated,
     setup_logging,
     simple_parse_args_string,
@@ -592,135 +589,23 @@ def evaluate(
         # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
         # Pre-process task.instances to group by doc_id
-        instances_by_doc_id = defaultdict(list)
-        for instance in task.instances:
-            instances_by_doc_id[instance.doc_id].append(instance)
-        # Sort instances within each group
-        for instances in instances_by_doc_id.values():
-            instances.sort(key=lambda x: x.idx)
+        # instances_by_doc_id = defaultdict(list)
+        # for instance in task.instances:
+        #     instances_by_doc_id[instance.doc_id].append(instance)
+        # # Sort instances within each group
+        # for instances in instances_by_doc_id.values():
+        #     instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
-        if hasattr(task, "calculate_metrics"):
-            metrics = task.calculate_metrics(
-                instances_by_doc_id=instances_by_doc_id,
-                samples=samples,
-                rank=RANK,
-                limit=limit,
-                world_size=WORLD_SIZE,
-            )
-        for filter_key in task.instances[0].filtered_resps.keys():
-            if hasattr(task, "calculate_metrics"):
-                # Add sample logging here too - similar to what's done in the else branch
-                if log_samples:
-                    indices = (
-                        samples.get(task_output.task_name, None)
-                        if samples is not None
-                        else None
-                    )
-                    doc_iterator = task.doc_iterator(
-                        rank=RANK,
-                        limit=limit,
-                        world_size=WORLD_SIZE,
-                        samples=indices,
-                    )
-                    for doc_id, doc in doc_iterator:
-                        doc_id_true = indices[doc_id] if indices else doc_id
-                        requests = instances_by_doc_id[doc_id]
-                        if requests:  # Make sure there are requests for this doc_id
-                            doc_metrics = metrics[filter_key][doc_id_true].metric_keys
-                            target = task.doc_to_target(doc)
-                            example = {
-                                "doc_id": doc_id_true,
-                                "doc": doc,
-                                "target": target,
-                                "arguments": [req.args for req in requests],
-                                "resps": [req.resps for req in requests],
-                                "filtered_resps": [
-                                    req.filtered_resps[filter_key] for req in requests
-                                ],
-                                "filter": filter_key,
-                                "metrics": doc_metrics,
-                                "doc_hash": hash_string(
-                                    json.dumps(
-                                        requests[0].doc,
-                                        indent=2,
-                                        default=handle_non_serializable,
-                                        ensure_ascii=False,
-                                    )
-                                ),
-                                "prompt_hash": hash_string(requests[0].arguments[0]),
-                                "target_hash": hash_string(str(target)),
-                            }
-                            example.update(
-                                {
-                                    metrics[filter_key][doc_id_true].metric_keys[
-                                        0
-                                    ]: metrics[filter_key][doc_id_true]
-                                }
-                            )
-                            task_output.logged_samples.append(example)
-                # Process all metrics returned from calculate_metrics
-                for filter_key in metrics:
-                    for m_samples in metrics[filter_key]:
-                        for metric, value in m_samples:
-                            task_output.sample_metrics[(metric, filter_key)].append(
-                                value
-                            )
-            else:
-                # Fall back to the original approach for non-ConfigurableTask instances
-                indices = (
-                    samples.get(task_output.task_name, None)
-                    if samples is not None
-                    else None
-                )
-                doc_iterator = task.doc_iterator(
-                    rank=RANK,
-                    limit=limit,
-                    world_size=WORLD_SIZE,
-                    samples=indices,
-                )
-                for doc_id, doc in doc_iterator:
-                    if indices:
-                        doc_id_true = indices[doc_id]
-                    else:
-                        doc_id_true = doc_id
-                    requests = instances_by_doc_id[doc_id]
-                    metrics: list[dict] = [
-                        task.process_results(doc, response)
-                        for req in requests
-                        for response in req.filtered_resps[filter_key]
-                    ]
-                    if log_samples:
-                        target = task.doc_to_target(doc)
-                        example = {
-                            "doc_id": doc_id_true,
-                            "doc": doc,
-                            "target": target,
-                            "arguments": [req.args for req in requests],
-                            "resps": [req.resps for req in requests],
-                            "filtered_resps": [
-                                req.filtered_resps[filter_key] for req in requests
-                            ],
-                            "filter": filter_key,
-                            "metrics": metrics,
-                            "doc_hash": hash_string(
-                                json.dumps(
-                                    requests[0].doc,
-                                    indent=2,
-                                    default=handle_non_serializable,
-                                    ensure_ascii=False,
-                                )
-                            ),
-                            "prompt_hash": hash_string(requests[0].arguments[0]),
-                            "target_hash": hash_string(str(target)),
-                        }
-                        example.update({"metrics": metrics})
-                        task_output.logged_samples.append(example)
-                    for x in metrics:
-                        for metric, value in x.items():
-                            task_output.sample_metrics[(metric, filter_key)].append(
-                                value
-                            )
+        _metrics, samples = task.calculate_metrics(
+            indices=samples,
+            rank=RANK,
+            limit=limit,
+            world_size=WORLD_SIZE,
+        )
+        task_output.agg_metrics = task.compute_agg_metrics(_metrics)
+        task_output.sample_metrics = _metrics
+        if log_samples:
+            task_output.logged_samples = samples
         if WORLD_SIZE > 1:
             # if multigpu, then gather data across all ranks to rank 0
@@ -756,8 +641,8 @@ def evaluate(
     if RANK == 0:
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
-        for task_output in eval_tasks:
-            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
+        # for task_output in eval_tasks:
+        #     task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
         (
             results,
             samples,
...
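The keys that `compute_agg_metrics` attaches to `task_output.agg_metrics` follow the existing `"metric,filter"` convention, which is what `make_table` later splits apart (the `m + "_stderr" + "," + f` lookup in the utils hunk below). An illustrative sketch of the resulting dict (values made up):

agg_metrics = {
    "acc,none": [0.75],         # one aggregate per repeat index
    "acc_stderr,none": 0.02,    # scalar, or "N/A" when a column has a single value
    "acc,none_repeat": [0.75],  # pass@k-style aggregate over repeated samples
}

# make_table-style lookup for the stderr column
m, f = "acc", "none"
se = agg_metrics.get(m + "_stderr" + "," + f, "N/A")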
@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r[0], resps)
+        return map(lambda r: r, resps)
 @register_filter("take_first_k")
...
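The one-character change above flips `TakeFirstFilter` from keeping only the first response to passing each response list through untouched (presumably so downstream metrics can see every repeated sample); note that the docstring still describes the old behaviour. A quick illustration with made-up responses:

resps = [["resp_a1", "resp_a2"], ["resp_b1", "resp_b2"]]

before = list(map(lambda r: r[0], resps))  # ["resp_a1", "resp_b1"]
after = list(map(lambda r: r, resps))      # [["resp_a1", "resp_a2"], ["resp_b1", "resp_b2"]]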
@@ -17,6 +17,8 @@ import numpy as np
 import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
+from lm_eval.api.instance import Instance
 SPACING = " " * 47
@@ -406,9 +408,13 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
             v = "%.4f" % v if isinstance(v, float) else v
             if m + "_stderr" + "," + f in dic:
-                se = dic[m + "_stderr" + "," + f]
-                se = " N/A" if se == "N/A" else "%.4f" % se
-                values.append([k, version, f, n, m, hib, v, "±", se])
+                try:
+                    se = dic[m + "_stderr" + "," + f]
+                    se = " N/A" if se == "N/A" else "%.4f" % se
+                    values.append([k, version, f, n, m, hib, v, "±", se])
+                except:  # noqa: E722
+                    values.append([k, version, f, n, m, hib, v, "", ""])
             else:
                 values.append([k, version, f, n, m, hib, v, "", ""])
             k = ""
@@ -550,3 +556,44 @@ def weighted_f1_score(items):
     preds = unzipped_list[1]
     fscore = f1_score(golds, preds, average="weighted")
     return fscore
+def create_sample_log(
+    doc: dict,
+    doc_id: int,
+    target: Any,
+    requests: list[Instance],
+    metric_names: [dict],
+    filter_key: str,
+) -> dict:
+    return {
+        "doc_id": doc_id,
+        "doc": doc,
+        "target": target,
+        "arguments": [req.args for req in requests],
+        "resps": [req.resps for req in requests],
+        "filtered_resps": [req.filtered_resps[filter_key] for req in requests],
+        "filter": filter_key,
+        "metrics": metric_names,
+        "doc_hash": hash_string(
+            json.dumps(
+                requests[0].doc,
+                indent=2,
+                default=handle_non_serializable,
+                ensure_ascii=False,
+            )
+        ),
+        "prompt_hash": hash_string(requests[0].arguments[0]),
+        "target_hash": hash_string(str(target)),
+    }
+def pass_at_k(n: int, c: int, k: int) -> float:
+    """
+    :param n: total number of samples
+    :param c: number of correct samples
+    :param k: k in pass@$k$
+    """
+    if n - c < k:
+        return 1.0
+    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
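`pass_at_k` is the usual unbiased pass@k estimator, 1 - C(n-c, k) / C(n, k), computed as a running product to avoid large binomial coefficients. A couple of worked values (illustrative usage; the import path is assumed from this diff):

from lm_eval.utils import pass_at_k  # assumed import path, matching the function added above

print(pass_at_k(n=10, c=3, k=1))   # 0.3, since pass@1 reduces to c / n
print(pass_at_k(n=10, c=3, k=8))   # 1.0, since n - c < k means a correct sample is always drawn
print(pass_at_k(n=4, c=2, k=2))    # ~0.8333, i.e. 1 - (1/3) * (1/2)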