Unverified commit 1ff84897, authored by Baber Abbasi, committed by GitHub

Evaluate (#1385)

* un-exclude `evaluator.py` from linting

* readability

* readability

* add task name to build info message

* fix link

* nit

* add functions for var and mean pooling

* add functions for var and mean pooling

* metadata compatibility with task

* rename `override_config` to `set_config` and move to `Task`

* add unit test

* nit

* nit

* bugfix

* nit

* nit

* nit

* add docstrings

* fix metadata-fewshot

* revert metric refactor

* nit

* type checking

* type hints

* type hints

* move `override_metric` to `Task`

* change metadata

* change name

* pre-commit

* rename

* remove

* remove

* `override_metric` backwards compatible with `Task`

* type hints

* use generic

* type hint
parent 1e6825da
 import logging
+from typing import Callable, Dict

 import evaluate

@@ -75,7 +76,7 @@ def register_group(name):
 OUTPUT_TYPE_REGISTRY = {}
 METRIC_REGISTRY = {}
 METRIC_AGGREGATION_REGISTRY = {}
-AGGREGATION_REGISTRY = {}
+AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
 HIGHER_IS_BETTER_REGISTRY = {}
 DEFAULT_METRIC_REGISTRY = {

@@ -118,7 +119,7 @@ def register_metric(**args):
     return decorate

-def get_metric(name, hf_evaluate_metric=False):
+def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
     if not hf_evaluate_metric:
         if name in METRIC_REGISTRY:
             return METRIC_REGISTRY[name]

@@ -136,7 +137,7 @@ def get_metric(name, hf_evaluate_metric=False):
     )

-def register_aggregation(name):
+def register_aggregation(name: str):
     def decorate(fn):
         assert (
             name not in AGGREGATION_REGISTRY

@@ -148,21 +149,21 @@ def register_aggregation(name):
     return decorate

-def get_aggregation(name):
+def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
     try:
         return AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} not a registered aggregation metric!")

-def get_metric_aggregation(name):
+def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
     try:
         return METRIC_AGGREGATION_REGISTRY[name]
     except KeyError:
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")

-def is_higher_better(metric_name):
+def is_higher_better(metric_name) -> bool:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
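These hunks only add `typing` hints; the decorator-registry pattern they annotate is unchanged. A minimal, self-contained sketch of that pattern (standalone names, not the library's actual module):

    import logging
    from typing import Callable, Dict

    eval_logger = logging.getLogger(__name__)

    AGGREGATION_REGISTRY: Dict[str, Callable] = {}


    def register_aggregation(name: str):
        # Decorator factory: store the wrapped function under `name`.
        def decorate(fn):
            assert name not in AGGREGATION_REGISTRY, f"aggregation '{name}' already registered!"
            AGGREGATION_REGISTRY[name] = fn
            return fn

        return decorate


    def get_aggregation(name: str) -> Callable:
        try:
            return AGGREGATION_REGISTRY[name]
        except KeyError:
            eval_logger.warning(f"{name} not a registered aggregation metric!")


    @register_aggregation("mean")
    def mean(items):
        return sum(items) / len(items)


    print(get_aggregation("mean")([0, 1, 1, 1]))  # 0.75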
...
@@ -357,7 +357,7 @@ class Task(abc.ABC):
         else:
             assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"

-        eval_logger.info(f"Building contexts for task on rank {rank}...")
+        eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []
         for doc_id, doc in utils.create_iterator(

@@ -511,6 +511,7 @@ class Task(abc.ABC):
         return description + labeled_examples + example

     def apply_filters(self):
+        """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
                 f.apply(self._instances)

@@ -519,15 +520,51 @@ class Task(abc.ABC):
         return self._instances

     def dump_config(self) -> dict:
-        """Returns a dictionary representing the task's config.
-        :returns: str
-            The fewshot context.
-        """
+        """Returns the config as a dictionary."""
         # TODO: this should only return the overrides applied to a non-YAML task's configuration.
         # (num_fewshot)
         return self.config.to_dict()

+    def set_config(self, key: str, value: Any, update: bool = False) -> None:
+        """Set or update the configuration for a given key."""
+        if key is None:
+            raise ValueError("Key must be provided.")
+
+        if update:
+            current_value = getattr(self._config, key, {})
+            if not isinstance(current_value, dict):
+                raise TypeError(
+                    f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
+                )
+            current_value.update(value)
+        else:
+            setattr(self._config, key, value)
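In short: `update=True` merges into an existing dict-valued field, anything else replaces the attribute wholesale. A sketch of those semantics; the config and task classes below are hypothetical stand-ins, and only the `set_config` body mirrors the method added above:

    from dataclasses import dataclass, field
    from typing import Any


    @dataclass
    class FakeTaskConfig:
        # hypothetical stand-in for the real task config object
        num_fewshot: int = 0
        generation_kwargs: dict = field(default_factory=dict)


    class FakeTask:
        def __init__(self) -> None:
            self._config = FakeTaskConfig()

        # same logic as the Task.set_config added in this diff
        def set_config(self, key: str, value: Any, update: bool = False) -> None:
            if key is None:
                raise ValueError("Key must be provided.")
            if update:
                current_value = getattr(self._config, key, {})
                if not isinstance(current_value, dict):
                    raise TypeError(f"Expected a dict for key '{key}'")
                current_value.update(value)
            else:
                setattr(self._config, key, value)


    task = FakeTask()
    task.set_config(key="num_fewshot", value=5)  # plain replace
    task.set_config(key="generation_kwargs", value={"do_sample": False}, update=True)  # merge
    print(task._config)  # FakeTaskConfig(num_fewshot=5, generation_kwargs={'do_sample': False})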
+    def override_metric(self, metric_name: str) -> None:
+        """
+        Override the default metrics used for evaluation with custom metrics.
+        Parameters:
+        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
+        """
+        (
+            self._metric_fn_list,
+            self._aggregation_list,
+            self._metric_fn_kwargs,
+            self._higher_is_better,
+        ) = ({}, {}, {}, {})
+        self._metric_fn_list[metric_name] = get_metric(metric_name)
+        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
+        self._higher_is_better[metric_name] = is_higher_better(metric_name)
+        self._metric_fn_kwargs[metric_name] = {}
+        if not isinstance(self, ConfigurableTask):
+            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
+            self.aggregation = lambda: {
+                metric_name: get_metric_aggregation(metric_name)
+            }
+        setattr(self._config, "metric_list", [{"metric": metric_name}])
+        setattr(self._config, "process_results", None)
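The backwards compatibility with plain `Task` subclasses comes from assigning `process_results` and `aggregation` on the instance, which shadows the class methods. A standalone sketch of that shadowing trick with made-up metrics (not lm-eval code):

    class LegacyTask:
        # stand-in for a pre-YAML Task subclass that hard-codes its metric
        def process_results(self, doc, results):
            return {"acc": float(results[0] == doc["answer"])}

        def aggregation(self):
            return {"acc": lambda items: sum(items) / len(items)}


    task = LegacyTask()

    # Instance attributes shadow the class methods, so later calls to
    # task.process_results / task.aggregation see the overridden metric.
    task.process_results = lambda doc, results: {"exact_match": float(results[0] == doc["answer"])}
    task.aggregation = lambda: {"exact_match": lambda items: sum(items) / len(items)}

    print(task.process_results({"answer": 1}, [1]))          # {'exact_match': 1.0}
    print(task.aggregation()["exact_match"]([1.0, 0.0]))     # 0.5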

 class ConfigurableTask(Task):
     VERSION = "Yaml"

@@ -833,6 +870,7 @@ class ConfigurableTask(Task):
         return labeled_examples + str(example)

     def apply_filters(self):
+        """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
                 f.apply(self._instances)

@@ -1222,37 +1260,6 @@ class ConfigurableTask(Task):
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)

-    def override_metric(self, metric_name: str) -> None:
-        """
-        Override the default metrics used for evaluation with custom metrics.
-        Parameters:
-        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
-        """
-        (
-            self._metric_fn_list,
-            self._aggregation_list,
-            self._metric_fn_kwargs,
-            self._higher_is_better,
-        ) = ({}, {}, {}, {})
-        self._metric_fn_list[metric_name] = get_metric(metric_name)
-        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
-        self._higher_is_better[metric_name] = is_higher_better(metric_name)
-        self._metric_fn_kwargs[metric_name] = {}
-        setattr(self._config, "metric_list", [{"metric": metric_name}])
-        setattr(self._config, "process_results", None)
-
-    def override_config(
-        self, key: str = None, value: Any = None, update: bool = False
-    ) -> None:
-        if update:
-            current_value = getattr(self._config, key)
-            assert isinstance(current_value, dict)
-            current_value.update(value)
-            setattr(self._config, key, current_value)
-        else:
-            setattr(self._config, key, value)

 class MultipleChoiceTask(Task):
     OUTPUT_TYPE: str = "loglikelihood"
...
-import random
-import itertools
 import collections
+import itertools
-import torch
 import logging
+import random
+from typing import Optional, Union

 import numpy as np
+import torch

-import lm_eval.api
-import lm_eval.models
 import lm_eval.api.metrics
 import lm_eval.api.registry
+import lm_eval.models
-from lm_eval.tasks import (
-    get_task_dict,
-    TaskManager
-)
+from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
+    eval_logger,
+    get_git_commit_hash,
     positional_deprecated,
     run_task_tests,
-    get_git_commit_hash,
     simple_parse_args_string,
-    eval_logger
 )


 @positional_deprecated
 def simple_evaluate(
     model,
-    model_args=None,
+    model_args: Optional[str] = None,
     tasks=None,
-    num_fewshot=None,
+    num_fewshot: Optional[int] = None,
-    batch_size=None,
+    batch_size: Optional[int] = None,
-    max_batch_size=None,
+    max_batch_size: Optional[int] = None,
-    device=None,
+    device: Optional[str] = None,
-    use_cache=None,
+    use_cache: Optional[str] = None,
-    limit=None,
+    limit: Optional[Union[int, float]] = None,
     bootstrap_iters: int = 100000,
     check_integrity: bool = False,
     decontamination_ngrams_path=None,
@@ -138,8 +133,8 @@ def simple_evaluate(
         eval_logger.info(
             "get_task_dict has been updated to accept an optional argument, `task_manager`"
-            "Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/recursive-groups/docs/interface.md#external-library-usage"
+            "Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
         )
     task_dict = get_task_dict(tasks, task_manager)
     for task_name in task_dict.keys():
         task_obj = task_dict[task_name]

@@ -150,7 +145,7 @@ def simple_evaluate(
         if task_obj.get_config("output_type") == "generate_until":
             if gen_kwargs is not None:
-                task_obj.override_config(
+                task_obj.set_config(
                     key="generation_kwargs", value=gen_kwargs, update=True
                 )

@@ -171,7 +166,7 @@ def simple_evaluate(
                 eval_logger.warning(
                     f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                 )
-                task_obj.override_config(key="num_fewshot", value=num_fewshot)
+                task_obj.set_config(key="num_fewshot", value=num_fewshot)

     if check_integrity:
         run_task_tests(task_list=tasks)
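With the annotated signature above, an external-library call looks roughly like the snippet below; the model, task name, and argument values are placeholders, and the linked interface.md is the authoritative reference for this usage:

    import lm_eval

    # Placeholder choices; `limit` keeps the run tiny for a smoke test.
    results = lm_eval.simple_evaluate(
        model="hf",
        model_args="pretrained=gpt2",
        tasks=["hellaswag"],
        num_fewshot=0,
        batch_size=8,
        limit=10,             # Optional[Union[int, float]]: doc count or fraction
        bootstrap_iters=1000,
    )
    print(results["results"])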
@@ -222,8 +217,8 @@ decontaminate_suffix = "_decontaminate"

 def evaluate(
     lm,
     task_dict,
-    limit=None,
+    limit: Optional[int] = None,
-    bootstrap_iters: int = 100000,
+    bootstrap_iters: Optional[int] = 100000,
     decontamination_ngrams_path=None,
     write_out: bool = False,
     log_samples: bool = True,

@@ -297,13 +292,9 @@ def evaluate(
         versions[task_name] = task.VERSION
         configs[task_name] = dict(task.dump_config())

-        if "num_fewshot" in configs[task_name]:
-            if configs[task_name]["metadata"]:
-                n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
-                if not n_shot:
-                    n_shot = configs[task_name]["num_fewshot"]
-        else:
-            n_shot = 0  # TODO: is this always right?
+        # Number of few-shots for printing.
+        if (n_shot := configs[task_name].get("num_fewshot")) == 0:
+            n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
         num_fewshot[task_name] = n_shot

         if "task_alias" in configs[task_name]:
@@ -483,36 +474,31 @@ def evaluate(
         vals = vals_torch

     if lm.rank == 0:
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
         for (task_name, key, metric), items in vals.items():
             task = task_dict[task_name]
-            metric_key = metric + "," + key
-            if isinstance(task, tuple):
-                group_name, task = task
-            else:
-                group_name = None
+            group_name, task = task if isinstance(task, tuple) else (None, task)
+            metric_key = f"{metric},{key}"

             agg_fn = task.aggregation()[metric]
             results[task_name][metric_key] = agg_fn(items)
             results[task_name]["samples"] = len(items)

             # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
             # so we run them less iterations. still looking for a cleaner way to do this
             if bootstrap_iters > 0:
-                stderr = lm_eval.api.metrics.stderr_for_metric(
-                    metric=task.aggregation()[metric],
+                stderr_fn = lm_eval.api.metrics.stderr_for_metric(
+                    metric=agg_fn,
                     bootstrap_iters=min(bootstrap_iters, 100)
                     if metric in ["bleu", "chrf", "ter"]
                     else bootstrap_iters,
                 )

-                if stderr is not None and len(items) > 1:
-                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
-                else:
-                    results[task_name][metric + "_stderr" + "," + key] = "N/A"
+                results[task_name][f"{metric}_stderr,{key}"] = (
+                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
+                )

     if bool(results):
         for group, task_list in reversed(task_hierarchy.items()):
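`stderr_for_metric` itself is untouched; the call site just reuses `agg_fn` and binds the result to `stderr_fn`. For intuition, a self-contained sketch of what a bootstrap standard-error estimator does (not the library's implementation):

    import random
    import statistics


    def bootstrap_stderr(agg_fn, items, bootstrap_iters=1000, seed=1234):
        # Re-aggregate `bootstrap_iters` resamples (with replacement) of the per-doc
        # scores and report the spread of those re-aggregations as the standard error.
        rng = random.Random(seed)
        resampled = [
            agg_fn(rng.choices(items, k=len(items))) for _ in range(bootstrap_iters)
        ]
        return statistics.stdev(resampled)


    accuracies = [1, 0, 1, 1, 0, 1, 0, 1]
    mean = lambda xs: sum(xs) / len(xs)
    print(mean(accuracies), bootstrap_stderr(mean, accuracies))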
@@ -523,22 +509,30 @@ def evaluate(
                 # we only want to operate on groups here.
                 continue

             for metric in [
-                key for key in results[task_list[0]].keys() if "_stderr" not in key and key not in ["alias", "samples"]
-            ]:  # TODO: what if tasks don't all share the same metrics
+                key
+                for key in results[task_list[0]].keys()
+                if "_stderr" not in key and key not in ["alias", "samples"]
+            ]:  # TODO: what if tasks don't all share the same metrics
                 stderr = "_stderr,".join(metric.split(","))

                 # gather metrics, sizes, and stderrs from subtasks
-                metrics = [results[task][metric] for task in task_list]  # TODO: copy?
+                metrics = [
+                    results[task][metric] for task in task_list
+                ]  # TODO: copy?
                 stderrs = [results[task][stderr] for task in task_list]
                 sizes = [results[task]["samples"] for task in task_list]

                 # compute group's pooled metric and stderr
-                results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
+                results[group][
+                    metric
+                ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                 # TODO: calculate grouped metric using aggregation fn
                 if "N/A" in stderrs:
                     results[group][stderr] = "N/A"
                 else:
-                    results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
+                    results[group][
+                        stderr
+                    ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                 # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                 # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                 # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
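The group row pools its subtasks by sample count. A sketch of the arithmetic this relies on, assuming independent subtask means; the exact formulas live in `lm_eval.api.metrics` and may differ:

    import math


    def weighted_mean(metrics, sizes):
        # Size-weighted average of the subtask scores.
        return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)


    def pooled_stderr(stderrs, sizes):
        # Standard error of that weighted average, treating subtask means as independent:
        # Var(sum(n_i * x_i) / N) = sum(n_i^2 * se_i^2) / N^2
        total = sum(sizes)
        return math.sqrt(sum((n * se) ** 2 for n, se in zip(sizes, stderrs))) / total


    print(weighted_mean([0.70, 0.50], [1000, 250]))    # 0.66
    print(pooled_stderr([0.014, 0.031], [1000, 250]))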
@@ -618,8 +612,10 @@ def evaluate(
             groups_agg = {**groups_agg, **_groups_agg}

     for group_name, task_list in task_hierarchy.items():
-        if task_list != []:
-            num_fewshot[group_name] = num_fewshot[task_list[0]]  # TODO: validate this
+        if task_list:
+            num_fewshot[group_name] = num_fewshot[
+                task_list[0]
+            ]  # TODO: validate this

     results_dict = {
         "results": dict(results_agg.items()),

...
@@ -88,7 +88,7 @@ all = [
 ]

 [tool.ruff]
-extend-exclude = ["lm_eval/evaluator.py", "lm_eval/tasks/*.py"]
+extend-exclude = ["lm_eval/tasks/*.py"]

 [tool.ruff.lint]
 extend-select = ["I"]

...
 import pytest

-from lm_eval.utils import Collator, get_rolling_token_windows, make_disjoint_window
+from lm_eval.utils import (
+    Collator,
+    get_rolling_token_windows,
+    make_disjoint_window,
+)

 # noinspection DuplicatedCode

...