Unverified commit 5ccd65d4 authored by Baber Abbasi, committed by GitHub

Refactor `evaluator.evaluate` (#1441)



* change `all_gather` to `gather`

* add TaskOutput utility class

* Add FilterResults class and refactor task handling.

* Rename `key` to `filter_key` for clarity

* Add `print_writeout` function in utils.py

* Add function to calculate limit size.

* Add doc_iterator method to Task class

* Refactor `doc_iterator` and cleanup in Task class

* remove superfluous bits

* change `all_gather` to `gather`

* bugfix

* bugfix

* fix `gather`

* Refactor `gather` loop

* Refactor aggregate metrics calculation

* Refactor and simplify aggregate metrics calculation
Removed unused code

* Simplify metrics calculation and remove unused code.

* simplify the metrics calculation in `utils.py` and `evaluator.py`.

* Fix group metric

* change evaluate to hf_evaluate

* change evaluate to hf_evaluate

* add docs

* add docs

* nits

* make isslice keyword only

* nit

* add todo

* nit

* nit

* nit: swap order samples_metrics tuple

* move instance sorting outside loop

* nit

* nit

* Add __repr__ for ConfigurableTask

* nit

* nit

* Revert "nit"

This reverts commit dab8d9977a643752a17f840fd8cf7e4b107df28f.

* fix some logging

* nit

* fix `predict_only` bug. thanks to `@LSinev`!

* change `print_tasks` to `prepare_print_tasks`

* nits

* move eval utils

* move eval utils

* nit

* add comment

* added tqdm descriptions

* Update lm_eval/evaluator_utils.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* fix mgsm bug

* nit

* fix `build_all_requests`

* pre-commit

* add ceil to limit

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 96d185fa
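
At a high level, the refactor moves per-task bookkeeping out of `evaluator.evaluate` and into the `TaskOutput` helpers added in `lm_eval/evaluator_utils.py` (full listing below). The following is a minimal, hypothetical usage sketch rather than the evaluator's actual loop: `StubTask` and the sample accuracy values are invented for illustration, while `TaskOutput.from_taskdict`, `calculate_aggregate_metric`, and `consolidate_results` are the names this PR introduces.

    from lm_eval.evaluator_utils import TaskOutput, consolidate_results

    class StubTask:
        # toy stand-in for a real Task; only the attributes TaskOutput touches
        VERSION = 0

        def dump_config(self):
            return {"task": "stub", "num_fewshot": 0, "metadata": {"num_fewshot": 0}}

        def aggregation(self):
            return {"acc": lambda items: sum(items) / len(items)}

    task_output = TaskOutput.from_taskdict("stub", StubTask())
    for value in (1, 0, 1, 1):  # pretend per-sample accuracies
        task_output.sample_metrics[("acc", "none")].append(value)
    task_output.calculate_aggregate_metric(bootstrap_iters=0)  # 0 skips stderr bootstrapping

    results, samples, configs, versions, num_fewshot = consolidate_results([task_output])
    print(results["stub"]["acc,none"])  # -> 0.75
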
@@ -225,7 +225,7 @@ class CachingLM:
             eval_logger.info(
                 f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
             )
-            for req in tqdm(requests):
+            for req in tqdm(requests, desc="Checking cached requests"):
                 hsh = hash_args(attr, req.args)
                 if attr == "generate_until" and req.args[1].get("do_sample", False):
                     # when we are doing non-greedy generation, don't use the cache
@@ -246,7 +246,9 @@ class CachingLM:
                 else:
                     res.append(None)
                     remaining_reqs.append(req)
+            eval_logger.info(
+                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
+            )
             # actually run the LM on the requests that do not have cached results
             rem_res = getattr(self.lm, attr)(remaining_reqs)
...
@@ -7,7 +7,7 @@ from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, List, Literal, Tuple, Union
+from typing import Any, Iterator, List, Literal, Tuple, Union

 import datasets
 import numpy as np
@@ -327,7 +327,7 @@ class Task(abc.ABC):
         return doc

     @property
-    def instances(self):
+    def instances(self) -> List[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -355,6 +355,7 @@ class Task(abc.ABC):
     def build_all_requests(
         self,
+        *,
         limit=None,
         rank=None,
         world_size=None,
@@ -382,13 +383,6 @@ class Task(abc.ABC):
             self._instances = flattened_instances
             return

-        if self.has_test_docs():
-            docs = self.test_docs()
-        elif self.has_validation_docs():
-            docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
-
         eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []
@@ -402,12 +396,7 @@ class Task(abc.ABC):
             limit = None

         doc_id_docs = list(
-            utils.create_iterator(
-                enumerate(docs),
-                rank,
-                world_size,
-                limit,
-            )
+            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
         )

         num_docs = len(doc_id_docs)
@@ -632,6 +621,27 @@ class Task(abc.ABC):
         setattr(self._config, "metric_list", [{"metric": metric_name}])
         setattr(self._config, "process_results", None)

+    @property
+    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
+        if self.has_test_docs():
+            return self.test_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+
+    def doc_iterator(
+        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
+    ) -> Iterator[Tuple[int, Any]]:
+        limit = int(limit) if limit else None
+        doc_iterator = utils.create_iterator(
+            enumerate(self.eval_docs),
+            rank=int(rank),
+            limit=limit,
+            world_size=int(world_size),
+        )
+        return doc_iterator
+

 class ConfigurableTask(Task):
     VERSION = "Yaml"
@@ -781,12 +791,7 @@ class ConfigurableTask(Task):
                 else "default"
             )(list(self.fewshot_docs()), self, rnd=random.Random(1234))

-        if self.has_test_docs():
-            self.task_docs = self.test_docs()
-        elif self.has_validation_docs():
-            self.task_docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+        self.task_docs = self.eval_docs

         # Test One Doc
         self.features = list(self.task_docs.features.keys())
@@ -1336,6 +1341,15 @@ class ConfigurableTask(Task):
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)

+    def __repr__(self):
+        return (
+            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
+            f"group_name={getattr(self.config, 'group', None)},"
+            f"output_type={self.OUTPUT_TYPE},"
+            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
+            f"num_samples={len(self.eval_docs)})"
+        )
+

 class MultipleChoiceTask(Task):
     OUTPUT_TYPE: str = "loglikelihood"
...
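
The new `Task.eval_docs` property and `Task.doc_iterator` method above replace the repeated test/validation branching; `doc_iterator` simply enumerates `eval_docs` and hands the result to `utils.create_iterator`, which slices it round-robin across ranks. A small standalone illustration of that slicing (the `docs` list is made up; `create_iterator`'s keyword-only signature is the one shown in the `utils.py` hunk further down):

    from lm_eval.utils import create_iterator

    docs = [{"text": f"doc {i}"} for i in range(10)]  # stand-in documents

    # rank 1 of a 4-way run sees every 4th (doc_id, doc) pair, starting at index 1
    print(list(create_iterator(enumerate(docs), rank=1, world_size=4, limit=None)))
    # -> [(1, {'text': 'doc 1'}), (5, {'text': 'doc 5'}), (9, {'text': 'doc 9'})]

    # `limit` is passed through as the islice stop value, so enumeration halts before doc_id 5
    print(list(create_iterator(enumerate(docs), rank=0, world_size=2, limit=5)))
    # -> [(0, {'text': 'doc 0'}), (2, {'text': 'doc 2'}), (4, {'text': 'doc 4'})]
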
This diff is collapsed.
lm_eval/evaluator_utils.py (new file):

import collections
import math
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union

from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated


class TaskOutput:
    """
    Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.

    Attributes:
        task (object): The task object.
        task_name (str): The name of the task.
        task_config (dict): The configuration of the task.
        version (str): The version of the task.
        group_name (str): The name of the task group.
        n_shot (int): The number of shots for the task.
        task_alias (str): The alias of the task.
        group_alias (str): The alias of the task group.
        is_group (bool): Indicates if the task is a group.
        logged_samples (list): The list of logged samples.
        sample_len (int): The length of the samples.
        sample_metrics (defaultdict): The dictionary of samples' metrics.
        agg_metrics (defaultdict): The dictionary of aggregate metrics.

    Methods:
        from_taskdict(cls, task_name: str, task):
            Creates a TaskOutput instance from a task dictionary.
        calculate_aggregate_metric(bootstrap_iters=100000) -> None:
            Calculates the aggregate metrics for the task.
    """

    def __init__(
        self,
        task=None,
        task_name=None,
        task_config=None,
        version=None,
        group_name=None,
        n_shot=None,
        task_alias=None,
        group_alias=None,
        is_group=None,
    ):
        self.task = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
        self.version = version
        self.n_shot = n_shot
        self.task_alias = task_alias
        self.group_alias = group_alias
        self.is_group = is_group
        self.logged_samples = []
        self.sample_len = None
        self.sample_metrics = collections.defaultdict(list)
        self.agg_metrics = collections.defaultdict(list)

    @classmethod
    def from_taskdict(cls, task_name: str, task):
        if isinstance(task, tuple):
            group_name, task = task
        else:
            group_name = None
        if not task:
            # these get filtered out in get_task_list
            # once they are added to the group hierarchy
            is_group = True
            return cls(
                task=task, task_name=task_name, is_group=is_group, group_name=group_name
            )
        version = task.VERSION
        task_config = dict(task.dump_config())
        if (n_shot := task_config.get("num_fewshot")) == 0:
            n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
        task_alias = task_config.get("alias")
        group_alias = task_config.get("group_alias")
        return cls(
            task=task,
            task_name=task_name,
            task_config=task_config,
            group_name=group_name,
            version=version,
            n_shot=n_shot,
            task_alias=task_alias,
            group_alias=group_alias,
        )

    def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
        for (metric, filter_key), items in self.sample_metrics.items():
            agg_fn = self.task.aggregation()[metric]
            metric_key = f"{metric},{filter_key}"
            self.agg_metrics[metric_key] = agg_fn(items)
            self.sample_len = len(items)  # TODO: same sample size for each metric?
            if bootstrap_iters:
                stderr_fn = metrics.stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )

    def __repr__(self):
        return (
            f"TaskOutput(task_name={self.task_name}, "
            f"group_name={self.group_name}, "
            f"version={self.version},"
            f"n_shot={self.n_shot}"
            f"task_alias={self.task_alias}, group_alias={self.group_alias})"
        )
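

# --- Editorial note (not part of the committed file) ------------------------
# sample_metrics is keyed by (metric, filter_key) tuples, and
# calculate_aggregate_metric() flattens each pair into the "metric,filter"
# string keys used everywhere else in the harness. With illustrative values:
#   task_output.sample_metrics[("acc", "none")] == [1, 0, 1, 1]
#   task_output.calculate_aggregate_metric()
#   task_output.agg_metrics["acc,none"]        == 0.75
#   task_output.agg_metrics["acc_stderr,none"] == <bootstrap stderr>, or "N/A"
#                                                 when only one sample is present
# -----------------------------------------------------------------------------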

def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
    task_hierarchy = collections.defaultdict(list)
    outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items())
    for task_output in outputs:
        if group_name := task_output.group_name:
            task_hierarchy[group_name].append(task_output.task_name)
        else:
            task_hierarchy[task_output.task_name] = []
    # returns task_hierarchy tracking which groups contain which subtasks,
    # and a list of TaskOutput classes for each non-group subtask
    return task_hierarchy, [x for x in outputs if x.task]


def print_writeout(task) -> None:
    for inst in task.instances:
        # print the prompt for the first few documents
        if inst.doc_id < 1:
            eval_logger.info(
                f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
            )
            eval_logger.info(f"Request: {str(inst)}")


def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
    if limit is not None:
        limit = (
            int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
        )
    return limit
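

# --- Editorial note (not part of the committed file) ------------------------
# get_sample_size() reads `limit` as a fraction of the eval split when it is
# below 1.0 and as an absolute count otherwise, rounding fractions up
# ("add ceil to limit" in the commit log). For a task with 1000 eval docs:
#   get_sample_size(task, 0.05) -> ceil(1000 * 0.05) = 50
#   get_sample_size(task, 200)  -> 200
#   get_sample_size(task, None) -> None (no cap)
# -----------------------------------------------------------------------------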

def prepare_print_tasks(
    task_hierarchy: dict, results: dict, tab=0
) -> Tuple[dict, dict]:
    """
    @param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
        value is a list of task names.
    @param results: Dictionary containing the results of each task. Each key is a
        group name and its value is a dictionary of task results.
    @param tab: The indentation level for printing the task
        hierarchy. Default is 0.
    @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
        aggregated results for each task, and groups_agg contains aggregated results for each group.

    Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
    """
    results_agg = collections.defaultdict(dict)
    groups_agg = collections.defaultdict(dict)

    (group_name, task_list), *_ = task_hierarchy.items()
    task_list = sorted(task_list)

    results_agg[group_name] = results[group_name].copy()
    # results_agg[group_name]["tab"] = tab
    if "samples" in results_agg[group_name]:
        results_agg[group_name].pop("samples")

    tab_string = " " * tab + "- " if tab > 0 else ""

    if "alias" in results_agg[group_name]:
        results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"]
    else:
        results_agg[group_name]["alias"] = tab_string + group_name

    if len(task_list) > 0:
        groups_agg[group_name] = results[group_name].copy()
        # groups_agg[group_name]["tab"] = tab
        if "samples" in groups_agg[group_name]:
            groups_agg[group_name].pop("samples")

        if "alias" in groups_agg[group_name]:
            groups_agg[group_name]["alias"] = (
                tab_string + groups_agg[group_name]["alias"]
            )
        else:
            groups_agg[group_name]["alias"] = tab_string + group_name

        for task_name in task_list:
            if task_name in task_hierarchy:
                _task_hierarchy = {
                    **{task_name: task_hierarchy[task_name]},
                    **task_hierarchy,
                }
            else:
                _task_hierarchy = {
                    **{task_name: []},
                    **task_hierarchy,
                }

            _results_agg, _groups_agg = prepare_print_tasks(
                _task_hierarchy, results, tab + 1
            )
            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

    return results_agg, groups_agg
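

# --- Editorial note (not part of the committed file) ------------------------
# prepare_print_tasks() walks the hierarchy recursively: the first entry of
# task_hierarchy is treated as the current group, each subtask is revisited
# with tab + 1, and aliases get a "- " prefix per indentation level.
# Illustrative shapes (task names hypothetical):
#   task_hierarchy == {"mmlu": ["mmlu_anatomy", "mmlu_astronomy"]}
#   results_agg["mmlu_anatomy"]["alias"] == " - mmlu_anatomy"   # tab == 1
#   groups_agg["mmlu"] holds the group-level aggregate row
# -----------------------------------------------------------------------------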

def consolidate_results(
    eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict]:
    """
    @param eval_tasks: list(TaskOutput).
    @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.

    Consolidates the results of multiple evaluation tasks into a single structure.

    The method iterates over each evaluation instance and extracts relevant information to create the consolidated
    results structure. The consolidated results structure has the following properties:

    - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
      metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
      aliases specified in the task configuration.
    - samples: A defaultdict with task names as keys and lists of log samples as values.
    - configs: A defaultdict with task names as keys and task configurations as values.
    - versions: A defaultdict with task names as keys and task versions as values.
    - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.

    The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
    """
    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    for task_output in eval_tasks:
        if "task_alias" in (task_config := task_output.task_config):
            results[task_output.task_name]["alias"] = task_config["task_alias"]
        if group_alias := task_output.group_alias:
            if group_alias not in results and (group_name := task_output.group_name):
                results[group_name]["alias"] = group_alias
        num_fewshot[task_output.task_name] = task_output.n_shot
        configs[task_output.task_name] = task_output.task_config
        versions[task_output.task_name] = task_output.version
        samples[task_output.task_name] = task_output.logged_samples
        for (metric, filter_key), items in task_output.sample_metrics.items():
            metric_key = f"{metric},{filter_key}"
            results[task_output.task_name][metric_key] = task_output.agg_metrics[
                metric_key
            ]
            results[task_output.task_name]["samples"] = task_output.sample_len
            results[task_output.task_name][
                f"{metric}_stderr,{filter_key}"
            ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
    return results, samples, configs, versions, num_fewshot
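

# --- Editorial note (not part of the committed file) ------------------------
# After consolidation, each per-task entry in `results` is a flat dict of
# "metric,filter" values plus stderrs and the sample count, roughly
# (values illustrative):
#   results["my_task"] == {
#       "alias": "my_task",
#       "acc,none": 0.75,
#       "acc_stderr,none": 0.01,
#       "samples": 1000,
#   }
# -----------------------------------------------------------------------------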

@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
@@ -921,7 +921,11 @@ class HFLM(TemplateLM):
         )

         chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running loglikelihood requests",
+        )
         for chunk in chunks:
             inps = []
             cont_toks_list = []
@@ -1089,7 +1093,11 @@ class HFLM(TemplateLM):
             toks = self.tok_encode(req[0])
             return -len(toks), req[0]

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(self.rank != 0),
+            desc="Running generate_until requests",
+        )
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
...
@@ -254,7 +254,11 @@ class VLLM(TemplateLM):
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(self.rank != 0),
+            desc="Running generate_until requests",
+        )
         # for each different set of kwargs, we execute all requests, by batch.
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
@@ -329,7 +333,11 @@ class VLLM(TemplateLM):
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )

-        pbar = tqdm(total=len(requests), disable=disable_tqdm)
+        pbar = tqdm(
+            total=len(requests),
+            disable=disable_tqdm,
+            desc="Running loglikelihood requests",
+        )
         for chunk in chunks:
             inputs = []
             ctxlens = []
...
@@ -6,9 +6,7 @@ import inspect
 import logging
 import os
 import re
-import sys
 from itertools import islice
-from pathlib import Path
 from typing import Any, Callable, List

 import numpy as np
@@ -244,7 +242,7 @@ def make_table(result_dict, column: str = "results"):
     values = []

     for k, dic in result_dict[column].items():
-        version = result_dict["versions"][k]
+        version = result_dict["versions"].get(k, "N/A")
         n = str(result_dict["n-shot"][k])

         if "alias" in dic:
@@ -292,47 +290,6 @@ def positional_deprecated(fn):
     return _wrapper


-@positional_deprecated
-def find_test_root(start_path: Path) -> Path:
-    """
-    Search upward in the directory tree to a maximum of three layers
-    to find and return the package root (containing the 'tests' folder)
-    """
-    cur_path = start_path.resolve()
-    max_layers = 3
-    for _ in range(max_layers):
-        if (cur_path / "tests" / "test_version_stable.py").exists():
-            return cur_path
-        else:
-            cur_path = cur_path.parent.resolve()
-    raise FileNotFoundError(
-        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
-    )
-
-
-@positional_deprecated
-def run_task_tests(task_list: List[str]):
-    """
-    Find the package root and run the tests for the given tasks
-    """
-    import pytest
-
-    package_root = find_test_root(start_path=Path(__file__))
-    task_string = " or ".join(task_list)
-    args = [
-        f"{package_root}/tests/test_version_stable.py",
-        f"--rootdir={package_root}",
-        "-k",
-        f"{task_string}",
-    ]
-    sys.path.append(str(package_root))
-    pytest_return_val = pytest.main(args)
-    if pytest_return_val:
-        raise ValueError(
-            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
-        )
-
-
 def ignore_constructor(loader, node):
     return node
@@ -414,16 +371,10 @@ def apply_template(template: str, doc: dict) -> str:
     return rtemplate.render(**doc)


-def create_iterator(raw_iterator, rank, world_size, limit=None):
+def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
     """
     Method for creating a (potentially) sliced and limited
     iterator from a raw document iterator. Used for splitting data
     among ranks in multigpu setting or only pulling a sample of documents
     """
     return islice(raw_iterator, rank, limit, world_size)

 # Multi-token stopping criteria
 # from more_itertools