Unverified commit 1ff84897 authored by Baber Abbasi, committed by GitHub

Evaluate (#1385)

* un-exclude `evaluate.py` from linting

* readability

* readability

* add task name to build info message

* fix link

* nit

* add functions for var and mean pooling

* add functions for var and mean pooling

* metadata compatibility with task

* rename `override_config` to `set_config` and move to `Task`

* add unit test

* nit

* nit

* bugfix

* nit

* nit

* nit

* add docstrings

* fix metadata-fewshot

* revert metric refactor

* nit

* type checking

* type hints

* type hints

* move `override_metric` to `Task`

* change metadata

* change name

* pre-commit

* rename

* remove

* remove

* `override_metric` backwards compatible with `Task`

* type hints

* use generic

* type hint
parent 1e6825da
import logging
from typing import Callable, Dict
import evaluate
@@ -75,7 +76,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
DEFAULT_METRIC_REGISTRY = {
@@ -118,7 +119,7 @@ def register_metric(**args):
return decorate
def get_metric(name, hf_evaluate_metric=False):
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
@@ -136,7 +137,7 @@ def get_metric(name, hf_evaluate_metric=False):
)
def register_aggregation(name):
def register_aggregation(name: str):
def decorate(fn):
assert (
name not in AGGREGATION_REGISTRY
@@ -148,21 +149,21 @@ def register_aggregation(name):
return decorate
def get_aggregation(name):
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name):
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name):
def is_higher_better(metric_name) -> bool:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......
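For orientation, a minimal usage sketch of the registry helpers typed above (not part of the diff). The `geom_mean` metric is made up, and the import path assumes these helpers live in lm_eval.api.registry, the file shown here.

import math

from lm_eval.api.registry import get_aggregation, register_aggregation


@register_aggregation("geom_mean")  # stored in AGGREGATION_REGISTRY under this name
def geom_mean(items):
    # geometric mean of per-example scores
    return math.exp(sum(math.log(x) for x in items) / len(items))


agg_fn = get_aggregation("geom_mean")  # returns the callable registered above
print(agg_fn([1.0, 4.0, 16.0]))  # 4.0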
@@ -357,7 +357,7 @@ class Task(abc.ABC):
else:
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
eval_logger.info(f"Building contexts for task on rank {rank}...")
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
instances = []
for doc_id, doc in utils.create_iterator(
@@ -511,6 +511,7 @@ class Task(abc.ABC):
return description + labeled_examples + example
def apply_filters(self):
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
@@ -519,15 +520,51 @@ class Task(abc.ABC):
return self._instances
def dump_config(self) -> dict:
"""Returns a dictionary representing the task's config.
:returns: str
The fewshot context.
"""
"""Returns the config as a dictionary."""
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
# (num_fewshot)
return self.config.to_dict()
def set_config(self, key: str, value: Any, update: bool = False) -> None:
"""Set or update the configuration for a given key."""
if key is None:
raise ValueError("Key must be provided.")
if update:
current_value = getattr(self._config, key, {})
if not isinstance(current_value, dict):
raise TypeError(
f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
)
current_value.update(value)
else:
setattr(self._config, key, value)
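A short usage sketch of set_config, not part of the diff; `task` stands for any Task instance and the generation kwargs are illustrative. It mirrors how simple_evaluate calls it further down: overwrite a scalar field, or merge a dict into an existing dict-valued field.

task.set_config(key="num_fewshot", value=5)  # overwrite a scalar field
task.set_config(
    key="generation_kwargs",
    value={"temperature": 0.0, "do_sample": False},  # merged into the existing dict
    update=True,
)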
def override_metric(self, metric_name: str) -> None:
"""
Override the default metrics used for evaluation with custom metrics.
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
if not isinstance(self, ConfigurableTask):
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
self.aggregation = lambda: {
metric_name: get_metric_aggregation(metric_name)
}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
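Similarly, a hedged sketch of override_metric in use; "exact_match" is only an example of a name registered in api.metrics.

task.override_metric("exact_match")
# From here on the task scores and aggregates only that metric, e.g.
# task.aggregation() -> {"exact_match": <aggregation fn>}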
class ConfigurableTask(Task):
VERSION = "Yaml"
@@ -833,6 +870,7 @@ class ConfigurableTask(Task):
return labeled_examples + str(example)
def apply_filters(self):
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
@@ -1222,37 +1260,6 @@ class ConfigurableTask(Task):
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
def override_metric(self, metric_name: str) -> None:
"""
Override the default metrics used for evaluation with custom metrics.
Parameters:
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
"""
(
self._metric_fn_list,
self._aggregation_list,
self._metric_fn_kwargs,
self._higher_is_better,
) = ({}, {}, {}, {})
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
self._metric_fn_kwargs[metric_name] = {}
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
def override_config(
self, key: str = None, value: Any = None, update: bool = False
) -> None:
if update:
current_value = getattr(self._config, key)
assert isinstance(current_value, dict)
current_value.update(value)
setattr(self._config, key, current_value)
else:
setattr(self._config, key, value)
class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood"
......
import random
import itertools
import collections
import torch
import itertools
import logging
import random
from typing import Optional, Union
import numpy as np
import torch
import lm_eval.api
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
from lm_eval.tasks import (
get_task_dict,
TaskManager
)
import lm_eval.models
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
eval_logger,
get_git_commit_hash,
positional_deprecated,
run_task_tests,
get_git_commit_hash,
simple_parse_args_string,
eval_logger
)
@positional_deprecated
def simple_evaluate(
model,
model_args=None,
model_args: Optional[str] = None,
tasks=None,
num_fewshot=None,
batch_size=None,
max_batch_size=None,
device=None,
use_cache=None,
limit=None,
num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
use_cache: Optional[str] = None,
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
decontamination_ngrams_path=None,
@@ -138,8 +133,8 @@ def simple_evaluate(
eval_logger.info(
"get_task_dict has been updated to accept an optional argument, `task_manager`"
"Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/recursive-groups/docs/interface.md#external-library-usage"
)
"Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
)
task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
@@ -150,7 +145,7 @@ def simple_evaluate(
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.override_config(
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
@@ -171,7 +166,7 @@ def simple_evaluate(
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.override_config(key="num_fewshot", value=num_fewshot)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
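For context, a hedged end-to-end sketch of the entry point these overrides feed into; the model string, task name, and kwargs are examples only, not taken from this diff.

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",
    tasks=["hellaswag"],
    num_fewshot=5,  # routed through task_obj.set_config above
    gen_kwargs="temperature=0,do_sample=False",  # merged into generation_kwargs for generate_until tasks
    limit=10,
)
print(results["results"])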
@@ -222,8 +217,8 @@ decontaminate_suffix = "_decontaminate"
def evaluate(
lm,
task_dict,
limit=None,
bootstrap_iters: int = 100000,
limit: Optional[int] = None,
bootstrap_iters: Optional[int] = 100000,
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
@@ -297,13 +292,9 @@ def evaluate(
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
if "num_fewshot" in configs[task_name]:
if configs[task_name]["metadata"]:
n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
if not n_shot:
n_shot = configs[task_name]["num_fewshot"]
else:
n_shot = 0 # TODO: is this always right?
# Number of few-shots for printing.
if (n_shot := configs[task_name].get("num_fewshot")) == 0:
n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
num_fewshot[task_name] = n_shot
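In other words (illustrative values only): an explicit num_fewshot of 0 in the task config defers to metadata for the number reported.

cfg = {"num_fewshot": 0, "metadata": {"num_fewshot": 25}}
if (n_shot := cfg.get("num_fewshot")) == 0:
    n_shot = cfg.get("metadata", {}).get("num_fewshot", 0)
assert n_shot == 25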
if "task_alias" in configs[task_name]:
@@ -483,36 +474,31 @@ def evaluate(
vals = vals_torch
if lm.rank == 0:
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
metric_key = metric + "," + key
if isinstance(task, tuple):
group_name, task = task
else:
group_name = None
group_name, task = task if isinstance(task, tuple) else (None, task)
metric_key = f"{metric},{key}"
agg_fn = task.aggregation()[metric]
results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
if bootstrap_iters > 0:
stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
stderr_fn = lm_eval.api.metrics.stderr_for_metric(
metric=agg_fn,
bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None and len(items) > 1:
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
else:
results[task_name][metric + "_stderr" + "," + key] = "N/A"
results[task_name][f"{metric}_stderr,{key}"] = (
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
)
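As background for the bleu/chrf/ter cap above, a generic sketch of why bootstrapping is expensive, not the implementation in lm_eval.api.metrics: each iteration re-aggregates a resample of the per-example scores, so a costly corpus-level aggregation multiplies that cost by the iteration count.

import random
import statistics


def bootstrap_stderr_sketch(aggregate_fn, items, iters=1000):
    # standard deviation of the aggregate over `iters` resamples drawn with replacement
    estimates = [
        aggregate_fn(random.choices(items, k=len(items))) for _ in range(iters)
    ]
    return statistics.stdev(estimates)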
if bool(results):
for group, task_list in reversed(task_hierarchy.items()):
@@ -523,22 +509,30 @@ def evaluate(
# we only want to operate on groups here.
continue
for metric in [
key for key in results[task_list[0]].keys() if "_stderr" not in key and key not in ["alias", "samples"]
]: # TODO: what if tasks don't all share the same metrics
key
for key in results[task_list[0]].keys()
if "_stderr" not in key and key not in ["alias", "samples"]
]: # TODO: what if tasks don't all share the same metrics
stderr = "_stderr,".join(metric.split(","))
# gather metrics, sizes, and stderrs from subtasks
metrics = [results[task][metric] for task in task_list] # TODO: copy?
metrics = [
results[task][metric] for task in task_list
] # TODO: copy?
stderrs = [results[task][stderr] for task in task_list]
sizes = [results[task]["samples"] for task in task_list]
# compute group's pooled metric and stderr
results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
results[group][
metric
] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
# TODO: calculate grouped metric using aggregation fn
if "N/A" in stderrs:
results[group][stderr] = "N/A"
else:
results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
results[group][
stderr
] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
# TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
# To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
# results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
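A hedged sketch of the size-weighted pooling the two helpers above perform; the exact formulas live in lm_eval.api.metrics and may differ in detail, and the stderr form assumes independent subtasks.

import math


def weighted_mean_sketch(metrics, sizes):
    # group score: sample-size-weighted mean of the subtask scores
    return sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)


def pooled_stderr_sketch(stderrs, sizes):
    # stderr of that weighted mean: sqrt(sum_i (w_i * se_i)^2) with w_i = n_i / N
    total = sum(sizes)
    return math.sqrt(sum((n / total * se) ** 2 for se, n in zip(stderrs, sizes)))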
@@ -618,8 +612,10 @@ def evaluate(
groups_agg = {**groups_agg, **_groups_agg}
for group_name, task_list in task_hierarchy.items():
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]] # TODO: validate this
if task_list:
num_fewshot[group_name] = num_fewshot[
task_list[0]
] # TODO: validate this
results_dict = {
"results": dict(results_agg.items()),
......
@@ -88,7 +88,7 @@ all = [
]
[tool.ruff]
extend-exclude = ["lm_eval/evaluator.py", "lm_eval/tasks/*.py"]
extend-exclude = ["lm_eval/tasks/*.py"]
[tool.ruff.lint]
extend-select = ["I"]
......
import pytest
from lm_eval.utils import Collator, get_rolling_token_windows, make_disjoint_window
from lm_eval.utils import (
Collator,
get_rolling_token_windows,
make_disjoint_window,
)
# noinspection DuplicatedCode
......