Unverified Commit 5ccd65d4 authored by Baber Abbasi, committed by GitHub

Refactor `evaluater.evaluate` (#1441)



* change `all_gather` to `gather`

* add TaskOutput utility class

* Add FilterResults class and refactor task handling.

* Rename `key` to `filter_key` for clarity

* Add `print_writeout` function in utils.py

* Add function to calculate limit size.

* Add doc_iterator method to Task class

* Refactor `doc_iterator` and cleanup in Task class

* remove superfluous bits

* change `all_gather` to `gather`

* bugfix

* bugfix

* fix `gather`

* Refactor `gather` loop

* Refactor aggregate metrics calculation

* Refactor and simplify aggregate metrics calculation
Removed unused code

* Simplify metrics calculation and remove unused code.

* simplify the metrics calculation in `utils.py` and `evaluator.py`.

* Fix group metric

* change evaluate to hf_evaluate

* change evaluate to hf_evaluate

* add docs

* add docs

* nits

* make isslice keyword only

* nit

* add todo

* nit

* nit

* nit: swap order samples_metrics tuple

* move instance sorting outside loop

* nit

* nit

* Add __repr__ for ConfigurableTask

* nit

* nit

* Revert "nit"

This reverts commit dab8d9977a643752a17f840fd8cf7e4b107df28f.

* fix some logging

* nit

* fix `predict_only` bug. thanks to `@LSinev`!

* change `print_tasks` to `prepare_print_tasks`

* nits

* move eval utils

* move eval utils

* nit

* add comment

* added tqdm descriptions

* Update lm_eval/evaluator_utils.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* fix mgsm bug

* nit

* fix `build_all_requests`

* pre-commit

* add ceil to limit

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 96d185fa
@@ -225,7 +225,7 @@ class CachingLM:
             eval_logger.info(
                 f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
             )
-            for req in tqdm(requests):
+            for req in tqdm(requests, desc="Checking cached requests"):
                 hsh = hash_args(attr, req.args)
                 if attr == "generate_until" and req.args[1].get("do_sample", False):
                     # when we are doing non-greedy generation, don't use the cache
@@ -246,7 +246,9 @@ class CachingLM:
                 else:
                     res.append(None)
                     remaining_reqs.append(req)
-
+            eval_logger.info(
+                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
+            )
             # actually run the LM on the requests that do not have cached results
             rem_res = getattr(self.lm, attr)(remaining_reqs)
...
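The hunk above adds a progress-bar description and a cache-hit summary to `CachingLM`. A minimal, standalone sketch of that partition-and-log pattern (illustrative only; `cache`, `hash_fn`, and `logger` are hypothetical stand-ins for the harness's request cache and `eval_logger`):

```python
from tqdm import tqdm

def split_cached(requests, cache, hash_fn, logger):
    # Partition requests into cache hits and misses, then log the split.
    res, remaining_reqs = [], []
    for req in tqdm(requests, desc="Checking cached requests"):
        hit = cache.get(hash_fn(req))
        if hit is not None:
            res.append(hit)
        else:
            res.append(None)  # placeholder, filled in once the LM answers the miss
            remaining_reqs.append(req)
    logger.info(
        f"Cached requests: {len(requests) - len(remaining_reqs)}, "
        f"Requests remaining: {len(remaining_reqs)}"
    )
    return res, remaining_reqs
```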
@@ -7,7 +7,7 @@ from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, List, Literal, Tuple, Union
+from typing import Any, Iterator, List, Literal, Tuple, Union

 import datasets
 import numpy as np
@@ -327,7 +327,7 @@ class Task(abc.ABC):
         return doc

     @property
-    def instances(self):
+    def instances(self) -> List[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -355,6 +355,7 @@ class Task(abc.ABC):
     def build_all_requests(
         self,
+        *,
         limit=None,
         rank=None,
         world_size=None,
@@ -382,13 +383,6 @@ class Task(abc.ABC):
             self._instances = flattened_instances
             return

-        if self.has_test_docs():
-            docs = self.test_docs()
-        elif self.has_validation_docs():
-            docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
-
         eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")

         instances = []
@@ -402,12 +396,7 @@ class Task(abc.ABC):
             limit = None

         doc_id_docs = list(
-            utils.create_iterator(
-                enumerate(docs),
-                rank,
-                world_size,
-                limit,
-            )
+            self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
         )

         num_docs = len(doc_id_docs)
@@ -632,6 +621,27 @@ class Task(abc.ABC):
         setattr(self._config, "metric_list", [{"metric": metric_name}])
         setattr(self._config, "process_results", None)

+    @property
+    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
+        if self.has_test_docs():
+            return self.test_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+
+    def doc_iterator(
+        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
+    ) -> Iterator[Tuple[int, Any]]:
+        limit = int(limit) if limit else None
+        doc_iterator = utils.create_iterator(
+            enumerate(self.eval_docs),
+            rank=int(rank),
+            limit=limit,
+            world_size=int(world_size),
+        )
+        return doc_iterator
+

 class ConfigurableTask(Task):
     VERSION = "Yaml"
@@ -781,12 +791,7 @@ class ConfigurableTask(Task):
                 else "default"
             )(list(self.fewshot_docs()), self, rnd=random.Random(1234))

-        if self.has_test_docs():
-            self.task_docs = self.test_docs()
-        elif self.has_validation_docs():
-            self.task_docs = self.validation_docs()
-        else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+        self.task_docs = self.eval_docs

         # Test One Doc
         self.features = list(self.task_docs.features.keys())
@@ -1336,6 +1341,15 @@ class ConfigurableTask(Task):
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)

+    def __repr__(self):
+        return (
+            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
+            f"group_name={getattr(self.config, 'group', None)},"
+            f"output_type={self.OUTPUT_TYPE},"
+            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
+            f"num_samples={len(self.eval_docs)})"
+        )
+

 class MultipleChoiceTask(Task):
     OUTPUT_TYPE: str = "loglikelihood"
...
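The `eval_docs`/`doc_iterator` additions above centralize the test-vs-validation split choice and the per-rank sharding. A minimal sketch of the same idea, assuming the `islice(raw, rank, limit, world_size)` semantics of `utils.create_iterator`; `TaskLike` is a hypothetical stand-in for the harness's `Task` class:

```python
from itertools import islice

class TaskLike:
    def __init__(self, test_docs=None, validation_docs=None):
        self._test, self._val = test_docs, validation_docs

    @property
    def eval_docs(self):
        # prefer the test split, fall back to validation, otherwise fail loudly
        if self._test is not None:
            return self._test
        if self._val is not None:
            return self._val
        raise ValueError("Task must have a test or validation split")

    def doc_iterator(self, *, rank=0, limit=None, world_size=1):
        limit = int(limit) if limit else None
        # enumerate first so doc_ids stay globally consistent across ranks
        return islice(enumerate(self.eval_docs), rank, limit, world_size)

task = TaskLike(validation_docs=["d0", "d1", "d2", "d3"])
print(list(task.doc_iterator(rank=1, world_size=2)))  # [(1, 'd1'), (3, 'd3')]
```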
 import collections
 import itertools
 import logging
-import math
 import random
 from typing import TYPE_CHECKING, Optional, Union

@@ -11,12 +10,19 @@ import torch
 import lm_eval.api.metrics
 import lm_eval.api.registry
 import lm_eval.models
+from lm_eval.evaluator_utils import (
+    consolidate_results,
+    get_sample_size,
+    get_task_list,
+    prepare_print_tasks,
+    print_writeout,
+    run_task_tests,
+)
 from lm_eval.logging_utils import add_env_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import (
     eval_logger,
     positional_deprecated,
-    run_task_tests,
     simple_parse_args_string,
 )
@@ -111,19 +117,23 @@ def simple_evaluate(
         eval_logger.info("Deleting requests cache...")
         delete_cache()

+    seed_message = []
     if random_seed is not None:
         # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
-        eval_logger.info(f"Setting random seed to {random_seed}")
+        seed_message.append(f"Setting random seed to {random_seed}")
         random.seed(random_seed)

     if numpy_random_seed is not None:
-        eval_logger.info(f"Setting numpy seed to {numpy_random_seed}")
+        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
         np.random.seed(numpy_random_seed)

     if torch_random_seed is not None:
-        eval_logger.info(f"Setting torch manual seed to {torch_random_seed}")
+        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
         torch.manual_seed(torch_random_seed)

+    if seed_message:
+        eval_logger.info(" | ".join(seed_message))
+
     if tasks is None:
         tasks = []
     assert (
@@ -166,7 +176,7 @@ def simple_evaluate(
         lm = model

     if use_cache is not None:
-        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
+        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
         lm = lm_eval.api.model.CachingLM(
             lm,
             use_cache
@@ -198,13 +208,13 @@ def simple_evaluate(
                     key="generation_kwargs", value=gen_kwargs, update=True
                 )

         if predict_only:
             log_samples = True
             eval_logger.info(
                 f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
             )
             # we have to change the class properties post-hoc. This is pretty hacky.
             task_obj.override_metric(metric_name="bypass")

         if num_fewshot is not None:
             if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
@@ -299,82 +309,22 @@ def evaluate(
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
     # decontaminate = decontamination_ngrams_path is not None

-    for task_name, task in task_dict.items():
-        if isinstance(task, tuple):
-            _, task = task
-        if not log_samples:
-            assert (
-                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
-            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"
-
-    # stores the final result for each task, for each metric/filter pair.
-    results = collections.defaultdict(dict)
-    # Tracks each task's version.
-    versions = collections.defaultdict(dict)
-    # Tracks the YAML configs of all chosen tasks.
-    configs = collections.defaultdict(dict)
-    # logs info about each document evaluated.
-    samples = collections.defaultdict(list)
     # tracks all Instances/requests a model must generate output on.
     requests = collections.defaultdict(list)
-    # Aggregated task scores presented with groups
-    results_agg = collections.defaultdict(dict)
-    # Aggregated groups scores only
-    groups_agg = collections.defaultdict(dict)
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
     padding_requests = collections.defaultdict(int)
-    # store the hierarchy to do proper ordering
-    task_hierarchy = collections.defaultdict(list)
-    # store num-fewshot value per task
-    num_fewshot = collections.defaultdict(int)
-
-    # get lists of each type of request
-    for task_name, task in task_dict.items():
-        task: Task
-        if isinstance(task, tuple):
-            group_name, task = task
-            task_hierarchy[group_name].append(task_name)
-            versions[group_name] = "N/A"
-        else:
-            group_name = None
-            task_hierarchy[task_name] = []
-
-        if task is None:
-            continue
-
-        versions[task_name] = task.VERSION
-        configs[task_name] = dict(task.dump_config())
-
-        # Number of few-shots for printing.
-        if (n_shot := configs[task_name].get("num_fewshot")) == 0:
-            n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
-        num_fewshot[task_name] = n_shot
-
-        if "task_alias" in configs[task_name]:
-            results[task_name]["alias"] = configs[task_name]["task_alias"]
-
-        if (
-            ("group_alias" in configs[task_name])
-            and (group_name not in results)
-            and (group_name is not None)
-        ):
-            results[group_name]["alias"] = configs[task_name]["group_alias"]
-
-        if limit is not None:
-            if task.has_test_docs():
-                task_docs = task.test_docs()
-            elif task.has_validation_docs():
-                task_docs = task.validation_docs()
-            else:
-                raise RuntimeError("Task has neither test_docs nor validation_docs")
-            num_docs = len(task_docs) * limit
-            # ceil to prevent limit being equal to 0
-            limit = int(math.ceil(num_docs)) if limit < 1.0 else int(limit)
-
+    # get lists of group hierarchy and each type of request
+    task_hierarchy, eval_tasks = get_task_list(task_dict)
+    if not log_samples:
+        assert all(
+            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
+            for task_output in eval_tasks
+        ), "log_samples must be True for 'bypass' only tasks"
+
+    for task_output in eval_tasks:
+        task: Task = task_output.task
+        limit = get_sample_size(task, limit)
         task.build_all_requests(
             limit=limit,
             rank=lm.rank,
@@ -382,21 +332,12 @@ def evaluate(
             cache_requests=cache_requests,
             rewrite_requests_cache=rewrite_requests_cache,
         )
         eval_logger.debug(
-            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
+            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
         )
         if write_out:
-            for inst in task.instances:
-                # print the prompt for the first few documents
-                if inst.doc_id < 1:
-                    eval_logger.info(
-                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
-\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
-                    )
-                    eval_logger.info(f"Request: {str(inst)}")
+            print_writeout(task)
         # aggregate Instances by LM method requested to get output.
         for instance in task.instances:
             reqtype = instance.request_type
@@ -408,7 +349,7 @@ def evaluate(
                 lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
             )

-            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
+            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
             numpad = max(gathered_item) - gathered_item[lm.rank]
             padding_requests[task.OUTPUT_TYPE] += numpad
@@ -435,42 +376,33 @@ def evaluate(
     if lm.world_size > 1:
         lm.accelerator.wait_for_everyone()

+    RANK = lm.rank
+    WORLD_SIZE = lm.world_size
     ### Postprocess outputs ###
     # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
-    for task_name, task in task_dict.items():
-        if isinstance(task, tuple):
-            group, task = task
-            if task is None:
-                continue
+    for task_output in eval_tasks:
+        task = task_output.task
         task.apply_filters()

-    ### Collect values of metrics on all datapoints ###
-    vals = collections.defaultdict(list)
-    # unpack results and sort back in order and return control to Task
-    for task_name, task in task_dict.items():
-        if isinstance(task, tuple):
-            group, task = task
-            if task is None:
-                continue
-
+        ### Collect values of metrics on all datapoints ###
+        # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
+        # Pre-process task.instances to group by doc_id
+        instances_by_doc_id = collections.defaultdict(list)
+        for instance in task.instances:
+            instances_by_doc_id[instance.doc_id].append(instance)
+        # Sort instances within each group
+        for instances in instances_by_doc_id.values():
+            instances.sort(key=lambda x: x.idx)
         # iterate over different filters used
-        for key in task.instances[0].filtered_resps.keys():
-            doc_iterator = (
-                itertools.islice(
-                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
-                )
-                if task.has_test_docs()
-                else itertools.islice(
-                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
-                )
-            )
+        for filter_key in task.instances[0].filtered_resps.keys():
+            doc_iterator = task.doc_iterator(
+                rank=RANK, limit=limit, world_size=WORLD_SIZE
+            )
             for doc_id, doc in doc_iterator:
-                # subset instances to only this document id ; sort by idx
-                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
-                requests.sort(key=lambda x: x.idx)
+                requests = instances_by_doc_id[doc_id]
                 metrics = task.process_results(
-                    doc, [req.filtered_resps[key] for req in requests]
+                    doc, [req.filtered_resps[filter_key] for req in requests]
                 )
                 if log_samples:
                     target = task.doc_to_target(doc)
@@ -480,93 +412,56 @@ def evaluate(
                         "target": target,
                         "arguments": [req.args for req in requests],
                         "resps": [req.resps for req in requests],
-                        "filtered_resps": [req.filtered_resps[key] for req in requests],
+                        "filtered_resps": [
+                            req.filtered_resps[filter_key] for req in requests
+                        ],
                     }
                     example.update(metrics)
-                    samples[task_name].append(example)
+                    task_output.logged_samples.append(example)
                 for metric, value in metrics.items():
-                    vals[(task_name, key, metric)].append(value)
+                    task_output.sample_metrics[(metric, filter_key)].append(value)

-    if lm.world_size > 1:
-        # if multigpu, then gather data across all ranks
+    if WORLD_SIZE > 1:
+        # if multigpu, then gather data across all ranks to rank 0
         # first gather logged samples across all ranks
-        for task_name, task_samples in list(samples.items()):
-            full_samples = [None] * lm.world_size
-            torch.distributed.all_gather_object(full_samples, task_samples)
-
-            samples[task_name] = list(itertools.chain.from_iterable(full_samples))
-
-        # then collect metrics across all ranks
-        vals_torch = collections.defaultdict(list)
-        for (task_name, key, metric), items in vals.items():
-            numitem = 0
-            if isinstance(items[0], tuple):
-                numitem = len(items[0])
-
-            if isinstance(items[0], (str, list, tuple)):
-                # handle the string case
-                gathered_items = [None] * lm.accelerator.num_processes
-                torch.distributed.all_gather_object(gathered_items, items)
-
-                gathered_item = list(itertools.chain.from_iterable(gathered_items))
-            else:
-                # distributed gather requires all ranks to have same dimensions
-                # so we pad out with float32 min value
-                pad_value = torch.finfo(torch.float32).min
-                metrics_tensor = torch.tensor(items, device=lm.device)
-
-                original_dtype = metrics_tensor.dtype  # store original dtype
-                torch_device_tensor = lm.accelerator.pad_across_processes(
-                    metrics_tensor.to(torch.float32), pad_index=pad_value
-                )
-                gathered_item = lm.accelerator.gather(torch_device_tensor)
-
-                if numitem > 0:
-                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
-                else:
-                    gathered_filtered = gathered_item[gathered_item != pad_value]
-
-                gathered_item = (
-                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
-                )
-                # reconvert if we were passed a tuple of values
-                if numitem > 0:
-                    gathered_item = [tuple(g) for g in gathered_item]
-
-            if lm.rank == 0:
-                vals_torch[(task_name, key, metric)] = gathered_item
-
-        vals = vals_torch
+        for task_output in eval_tasks:
+            if log_samples:
+                # for task_name, task_samples in list(samples.items()):
+                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
+                torch.distributed.gather_object(
+                    obj=task_output.logged_samples,
+                    object_gather_list=full_samples,
+                    dst=0,
+                )
+                if RANK == 0:
+                    task_output.logged_samples = list(
+                        itertools.chain.from_iterable(full_samples)
+                    )
+
+            # then collect metrics across all ranks
+            for metrics in task_output.sample_metrics:
+                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
+                torch.distributed.gather_object(
+                    obj=task_output.sample_metrics[metrics],
+                    object_gather_list=metric_list,
+                    dst=0,
+                )
+                if RANK == 0:
+                    task_output.sample_metrics[metrics] = list(
+                        itertools.chain.from_iterable(metric_list)
+                    )

-    if lm.rank == 0:
+    if RANK == 0:
         ### Aggregate results over all datapoints ###
         # aggregate results ; run bootstrap CIs
-        for (task_name, key, metric), items in vals.items():
-            task = task_dict[task_name]
-            group_name, task = task if isinstance(task, tuple) else (None, task)
-
-            metric_key = f"{metric},{key}"
-
-            agg_fn = task.aggregation()[metric]
-            results[task_name][metric_key] = agg_fn(items)
-            results[task_name]["samples"] = len(items)
-
-            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
-            # so we run them less iterations. still looking for a cleaner way to do this
-            if bootstrap_iters > 0:
-                stderr_fn = lm_eval.api.metrics.stderr_for_metric(
-                    metric=agg_fn,
-                    bootstrap_iters=(
-                        min(bootstrap_iters, 100)
-                        if metric in ["bleu", "chrf", "ter"]
-                        else bootstrap_iters
-                    ),
-                )
-                results[task_name][f"{metric}_stderr,{key}"] = (
-                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
-                )
-
+        for task_output in eval_tasks:
+            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
+        results, samples, configs, versions, num_fewshot = consolidate_results(
+            eval_tasks
+        )
+
+        ### Calculate group metrics ###
         if bool(results):
             for group, task_list in reversed(task_hierarchy.items()):
                 if len(task_list) == 0:
@@ -575,19 +470,33 @@ def evaluate(
                     # or `task_name: []`.
                     # we only want to operate on groups here.
                     continue
-                for metric in [
-                    key
-                    for key in results[task_list[0]].keys()
-                    if "_stderr" not in key and key not in ["alias", "samples"]
-                ]:  # TODO: what if tasks don't all share the same metrics
+                metric_list = list(
+                    {
+                        key
+                        for task in task_list
+                        for key in results[task].keys()
+                        if "_stderr" not in key and key not in ["alias", "samples"]
+                    }
+                )
+                for metric in metric_list:
                     stderr = "_stderr,".join(metric.split(","))

                     # gather metrics, sizes, and stderrs from subtasks
                     metrics = [
-                        results[task][metric] for task in task_list
+                        results[task][metric]
+                        for task in task_list
+                        if metric in results[task]
                     ]  # TODO: copy?
-                    stderrs = [results[task][stderr] for task in task_list]
-                    sizes = [results[task]["samples"] for task in task_list]
+                    stderrs = [
+                        results[task][stderr]
+                        for task in task_list
+                        if stderr in results[task]
+                    ]
+                    sizes = [
+                        results[task]["samples"]
+                        for task in task_list
+                        if metric in results[task]
+                    ]

                     # compute group's pooled metric and stderr
                     results[group][
@@ -606,60 +515,6 @@ def evaluate(
                     results[group]["samples"] = sum(sizes)

-        def print_tasks(task_hierarchy, results, tab=0):
-            results_agg = collections.defaultdict(dict)
-            groups_agg = collections.defaultdict(dict)
-
-            (group_name, task_list), *_ = task_hierarchy.items()
-            task_list = sorted(task_list)
-
-            results_agg[group_name] = results[group_name].copy()
-            # results_agg[group_name]["tab"] = tab
-            if "samples" in results_agg[group_name]:
-                results_agg[group_name].pop("samples")
-
-            tab_string = " " * tab + "- " if tab > 0 else ""
-
-            if "alias" in results_agg[group_name]:
-                results_agg[group_name]["alias"] = (
-                    tab_string + results_agg[group_name]["alias"]
-                )
-            else:
-                results_agg[group_name]["alias"] = tab_string + group_name
-
-            if len(task_list) > 0:
-                groups_agg[group_name] = results[group_name].copy()
-                # groups_agg[group_name]["tab"] = tab
-                if "samples" in groups_agg[group_name]:
-                    groups_agg[group_name].pop("samples")
-
-                if "alias" in groups_agg[group_name]:
-                    groups_agg[group_name]["alias"] = (
-                        tab_string + groups_agg[group_name]["alias"]
-                    )
-                else:
-                    groups_agg[group_name]["alias"] = tab_string + group_name
-
-                for task_name in task_list:
-                    if task_name in task_hierarchy:
-                        _task_hierarchy = {
-                            **{task_name: task_hierarchy[task_name]},
-                            **task_hierarchy,
-                        }
-                    else:
-                        _task_hierarchy = {
-                            **{task_name: []},
-                            **task_hierarchy,
-                        }
-
-                    _results_agg, _groups_agg = print_tasks(
-                        _task_hierarchy, results, tab + 1
-                    )
-
-                    results_agg = {**results_agg, **_results_agg}
-                    groups_agg = {**groups_agg, **_groups_agg}
-
-            return results_agg, groups_agg
-
         results_agg = collections.defaultdict(dict)
         groups_agg = collections.defaultdict(dict)
         all_tasks_list = list(task_hierarchy.keys())
@@ -673,7 +528,7 @@ def evaluate(
             _task_hierarchy = {
                 k: v for k, v in task_hierarchy.items() if k in left_tasks_list
             }
-            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)
+            _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)
             results_agg = {**results_agg, **_results_agg}
             groups_agg = {**groups_agg, **_groups_agg}
...
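The distributed section above replaces `all_gather_object` with a rank-0-only `gather_object`. A hedged sketch of that pattern in isolation, assuming `torch.distributed` has already been initialised; `gather_to_rank0` is a hypothetical helper, not harness API:

```python
import itertools
import torch.distributed as dist

def gather_to_rank0(local_items, rank: int, world_size: int):
    # Only rank 0 allocates the receive buffer; all other ranks pass None.
    buffer = [None] * world_size if rank == 0 else None
    dist.gather_object(obj=local_items, object_gather_list=buffer, dst=0)
    if rank == 0:
        # flatten the per-rank lists into one list, as the evaluator does
        return list(itertools.chain.from_iterable(buffer))
    return None  # non-zero ranks skip aggregation entirely
```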
import collections
import math
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union

from lm_eval.api import metrics
from lm_eval.utils import eval_logger, positional_deprecated


class TaskOutput:
    """
    Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.

    Attributes:
        task (object): The task object.
        task_name (str): The name of the task.
        task_config (dict): The configuration of the task.
        version (str): The version of the task.
        group_name (str): The name of the task group.
        n_shot (int): The number of shots for the task.
        task_alias (str): The alias of the task.
        group_alias (str): The alias of the task group.
        is_group (bool): Indicates if the task is a group.
        logged_samples (list): The list of logged samples.
        sample_len (int): The length of the samples.
        sample_metrics (defaultdict): The dictionary of samples' metrics.
        agg_metrics (defaultdict): The dictionary of aggregate metrics.

    Methods:
        from_taskdict(cls, task_name: str, task):
            Creates a TaskOutput instance from a task dictionary.

        calculate_aggregate_metric(bootstrap_iters=100000) -> None:
            Calculates the aggregate metrics for the task.
    """

    def __init__(
        self,
        task=None,
        task_name=None,
        task_config=None,
        version=None,
        group_name=None,
        n_shot=None,
        task_alias=None,
        group_alias=None,
        is_group=None,
    ):
        self.task = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
        self.version = version
        self.n_shot = n_shot
        self.task_alias = task_alias
        self.group_alias = group_alias
        self.is_group = is_group
        self.logged_samples = []
        self.sample_len = None
        self.sample_metrics = collections.defaultdict(list)
        self.agg_metrics = collections.defaultdict(list)

    @classmethod
    def from_taskdict(cls, task_name: str, task):
        if isinstance(task, tuple):
            group_name, task = task
        else:
            group_name = None
        if not task:
            # these gets filtered out in get_task_list
            # once they are added to group hierarchy
            is_group = True
            return cls(
                task=task, task_name=task_name, is_group=is_group, group_name=group_name
            )
        version = task.VERSION
        task_config = dict(task.dump_config())
        if (n_shot := task_config.get("num_fewshot")) == 0:
            n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
        task_alias = task_config.get("alias")
        group_alias = task_config.get("group_alias")
        return cls(
            task=task,
            task_name=task_name,
            task_config=task_config,
            group_name=group_name,
            version=version,
            n_shot=n_shot,
            task_alias=task_alias,
            group_alias=group_alias,
        )

    def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
        for (metric, filter_key), items in self.sample_metrics.items():
            agg_fn = self.task.aggregation()[metric]
            metric_key = f"{metric},{filter_key}"
            self.agg_metrics[metric_key] = agg_fn(items)
            self.sample_len = len(items)  # TODO: same sample size for each metric?
            if bootstrap_iters:
                stderr_fn = metrics.stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )

    def __repr__(self):
        return (
            f"TaskOutput(task_name={self.task_name}, "
            f"group_name={self.group_name}, "
            f"version={self.version},"
            f"n_shot={self.n_shot}"
            f"task_alias={self.task_alias}, group_alias={self.group_alias})"
        )


def get_task_list(task_dict: dict) -> Tuple[Dict[str, list], List[TaskOutput]]:
    task_hierarchy = collections.defaultdict(list)
    outputs = list(TaskOutput.from_taskdict(x, y) for x, y in task_dict.items())
    for task_output in outputs:
        if group_name := task_output.group_name:
            task_hierarchy[group_name].append(task_output.task_name)
        else:
            task_hierarchy[task_output.task_name] = []
    # returns task_hierarchy tracking which groups contain which subtasks,
    # and a list of TaskOutput classes for each non-group subtask
    return task_hierarchy, [x for x in outputs if x.task]


def print_writeout(task) -> None:
    for inst in task.instances:
        # print the prompt for the first few documents
        if inst.doc_id < 1:
            eval_logger.info(
                f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
            )
            eval_logger.info(f"Request: {str(inst)}")


def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
    if limit is not None:
        limit = (
            int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
        )
    return limit


def prepare_print_tasks(
    task_hierarchy: dict, results: dict, tab=0
) -> Tuple[dict, dict]:
    """
    @param task_hierarchy: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
    value is a list of task names.
    @param results: Dictionary containing the results of each task. Each key is a
    group name and its value is a dictionary of task results.
    @param tab: The indentation level for printing the task
    hierarchy. Default is 0.
    @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
    aggregated results for each task, and groups_agg contains aggregated results for each group.

    Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
    """
    results_agg = collections.defaultdict(dict)
    groups_agg = collections.defaultdict(dict)

    (group_name, task_list), *_ = task_hierarchy.items()
    task_list = sorted(task_list)

    results_agg[group_name] = results[group_name].copy()
    # results_agg[group_name]["tab"] = tab
    if "samples" in results_agg[group_name]:
        results_agg[group_name].pop("samples")

    tab_string = " " * tab + "- " if tab > 0 else ""

    if "alias" in results_agg[group_name]:
        results_agg[group_name]["alias"] = tab_string + results_agg[group_name]["alias"]
    else:
        results_agg[group_name]["alias"] = tab_string + group_name

    if len(task_list) > 0:
        groups_agg[group_name] = results[group_name].copy()
        # groups_agg[group_name]["tab"] = tab
        if "samples" in groups_agg[group_name]:
            groups_agg[group_name].pop("samples")

        if "alias" in groups_agg[group_name]:
            groups_agg[group_name]["alias"] = (
                tab_string + groups_agg[group_name]["alias"]
            )
        else:
            groups_agg[group_name]["alias"] = tab_string + group_name

        for task_name in task_list:
            if task_name in task_hierarchy:
                _task_hierarchy = {
                    **{task_name: task_hierarchy[task_name]},
                    **task_hierarchy,
                }
            else:
                _task_hierarchy = {
                    **{task_name: []},
                    **task_hierarchy,
                }

            _results_agg, _groups_agg = prepare_print_tasks(
                _task_hierarchy, results, tab + 1
            )
            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

    return results_agg, groups_agg


def consolidate_results(
    eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict]:
    """
    @param eval_tasks: list(TaskOutput).
    @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.

    Consolidates the results of multiple evaluation tasks into a single structure.

    The method iterates over each evaluation instance and extracts relevant information to create the consolidated
    results structure. The consolidated results structure has the following properties:

    - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
    metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
    aliases specified in the task configuration.
    - samples: A defaultdict with task names as keys and lists of log samples as values.
    - configs: A defaultdict with task names as keys and task configurations as values.
    - versions: A defaultdict with task names as keys and task versions as values.
    - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.

    The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
    """
    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
    # Tracks the YAML configs of all chosen task
    configs = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    for task_output in eval_tasks:
        if "task_alias" in (task_config := task_output.task_config):
            results[task_output.task_name]["alias"] = task_config["task_alias"]
        if group_alias := task_output.group_alias:
            if group_alias not in results and (group_name := task_output.group_name):
                results[group_name]["alias"] = group_alias
        num_fewshot[task_output.task_name] = task_output.n_shot
        configs[task_output.task_name] = task_output.task_config
        versions[task_output.task_name] = task_output.version
        samples[task_output.task_name] = task_output.logged_samples
        for (metric, filter_key), items in task_output.sample_metrics.items():
            metric_key = f"{metric},{filter_key}"
            results[task_output.task_name][metric_key] = task_output.agg_metrics[
                metric_key
            ]
            results[task_output.task_name]["samples"] = task_output.sample_len
            results[task_output.task_name][
                f"{metric}_stderr,{filter_key}"
            ] = task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
    return results, samples, configs, versions, num_fewshot


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )
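A worked example of the limit handling introduced by `get_sample_size` (simplified here to take a document count rather than a task object): a `limit` below 1.0 is read as a fraction of the evaluation split and rounded up with `math.ceil` so it can never become zero, while larger values are treated as absolute counts.

```python
import math

def sample_size(num_docs: int, limit):
    # simplified stand-in for get_sample_size(task, limit)
    if limit is not None:
        limit = int(math.ceil(num_docs * limit)) if limit < 1.0 else int(limit)
    return limit

assert sample_size(1319, 0.001) == 2    # ceil(1.319) == 2, never rounds to zero
assert sample_size(1319, 50) == 50      # absolute count passes through
assert sample_size(1319, None) is None  # no limit requested
```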
@@ -921,7 +921,11 @@ class HFLM(TemplateLM):
         )

         chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running loglikelihood requests",
+        )
         for chunk in chunks:
             inps = []
             cont_toks_list = []
@@ -1089,7 +1093,11 @@ class HFLM(TemplateLM):
             toks = self.tok_encode(req[0])
             return -len(toks), req[0]

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(self.rank != 0),
+            desc="Running generate_until requests",
+        )
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
...
@@ -254,7 +254,11 @@ class VLLM(TemplateLM):
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        pbar = tqdm(
+            total=len(requests),
+            disable=(self.rank != 0),
+            desc="Running generate_until requests",
+        )
         # for each different set of kwargs, we execute all requests, by batch.
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
@@ -329,7 +333,11 @@ class VLLM(TemplateLM):
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )

-        pbar = tqdm(total=len(requests), disable=disable_tqdm)
+        pbar = tqdm(
+            total=len(requests),
+            disable=disable_tqdm,
+            desc="Running loglikelihood requests",
+        )
         for chunk in chunks:
             inputs = []
             ctxlens = []
...
@@ -6,9 +6,7 @@ import inspect
 import logging
 import os
 import re
-import sys
 from itertools import islice
-from pathlib import Path
 from typing import Any, Callable, List

 import numpy as np
@@ -244,7 +242,7 @@ def make_table(result_dict, column: str = "results"):
     values = []

     for k, dic in result_dict[column].items():
-        version = result_dict["versions"][k]
+        version = result_dict["versions"].get(k, "N/A")
         n = str(result_dict["n-shot"][k])

         if "alias" in dic:
@@ -292,47 +290,6 @@ def positional_deprecated(fn):
     return _wrapper


-@positional_deprecated
-def find_test_root(start_path: Path) -> Path:
-    """
-    Search upward in the directory tree to a maximum of three layers
-    to find and return the package root (containing the 'tests' folder)
-    """
-    cur_path = start_path.resolve()
-    max_layers = 3
-    for _ in range(max_layers):
-        if (cur_path / "tests" / "test_version_stable.py").exists():
-            return cur_path
-        else:
-            cur_path = cur_path.parent.resolve()
-    raise FileNotFoundError(
-        f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
-    )
-
-
-@positional_deprecated
-def run_task_tests(task_list: List[str]):
-    """
-    Find the package root and run the tests for the given tasks
-    """
-    import pytest
-
-    package_root = find_test_root(start_path=Path(__file__))
-    task_string = " or ".join(task_list)
-    args = [
-        f"{package_root}/tests/test_version_stable.py",
-        f"--rootdir={package_root}",
-        "-k",
-        f"{task_string}",
-    ]
-    sys.path.append(str(package_root))
-    pytest_return_val = pytest.main(args)
-    if pytest_return_val:
-        raise ValueError(
-            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
-        )
-
-
 def ignore_constructor(loader, node):
     return node
@@ -414,16 +371,10 @@ def apply_template(template: str, doc: dict) -> str:
     return rtemplate.render(**doc)


-def create_iterator(raw_iterator, rank, world_size, limit=None):
+def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
     """
     Method for creating a (potentially) sliced and limited
     iterator from a raw document iterator. Used for splitting data
     among ranks in multigpu setting or only pulling a sample of documents
     """
     return islice(raw_iterator, rank, limit, world_size)
-
-
-# Multi-token stopping criteria
-# from more_itertools
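With the change above, `create_iterator` takes `rank`, `world_size`, and `limit` as keyword-only arguments. A small usage sketch under the `islice` semantics shown in the diff:

```python
from itertools import islice

def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    # each rank sees every world_size-th item starting at its rank, up to limit
    return islice(raw_iterator, rank, limit, world_size)

docs = enumerate(["q1", "q2", "q3", "q4", "q5"])
print(list(create_iterator(docs, rank=0, world_size=2, limit=4)))
# [(0, 'q1'), (2, 'q3')]
```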