Unverified Commit 4ee1b386 authored by LSinev, committed by GitHub

Cleanup and fixes (Task, Instance, and a little bit of *evaluate) (#1533)



* Remove unused `decontamination_ngrams_path` and all mentions (still no alternative path provided)

* Fix improper import of LM and usage of evaluator in one of the scripts

* Update type hints in the Instance and Task APIs

* Raise errors in task.py instead of asserts (pattern sketched below)

* Fix warnings from ruff

* Raise errors in __main__.py instead of asserts

* Raise errors in tasks/__init__.py instead of asserts

* Raise errors in evaluator.py instead of asserts

* evaluator: update type hints and remove unused variables

* Update lm_eval/__main__.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/__main__.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/evaluator.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Pre-commit-induced fixes

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 02705057
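A minimal sketch of the assert-to-exception pattern applied throughout this PR, built around the `--output_path` check from `__main__.py`. The helper name is hypothetical; in the actual diff the check is inlined in `cli_evaluate`:

```python
import argparse


def check_output_path(args: argparse.Namespace) -> None:
    # Old style: stripped entirely when Python runs with -O, so the check
    # silently disappears in optimized mode.
    #   assert args.output_path, "Specify --output_path"

    # New style: an explicit exception that always fires and can be caught
    # or reported like any other error.
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )
```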
......@@ -2,15 +2,14 @@
## Usage
Simply add a "--decontamination_ngrams_path" when running \__main\__.py. The provided directory should contain
The provided directory should contain
the ngram files and info.json produced in "Pile Ngram Generation" further down.
```bash
python -m lm_eval \
--model gpt2 \
--device 0 \
--tasks sciq \
--decontamination_ngrams_path path/containing/training/set/ngrams
--tasks sciq
```
## Background
......@@ -70,5 +69,3 @@ python -m scripts/clean_training_data/compress_and_package \
-output path/to/final/directory \
-procs 8
```
Congratulations, the final directory can now be passed to lm-evaluation-harness with the "--decontamination_ngrams_path" argument.
......@@ -36,8 +36,6 @@ This mode supports a number of command-line arguments, the details of which can
- `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
- `--decontamination_ngrams_path` : Deprecated; see [this commit](https://github.com/EleutherAI/lm-evaluation-harness/commit/00209e10f6e27edf5d766145afaf894079b5fe10) or older for a working decontamination-checker tool.
- `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
- `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
......
......@@ -17,6 +17,9 @@ from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table, simple_parse_args_string
DEFAULT_RESULTS_FILE = "results.json"
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
......@@ -127,7 +130,6 @@ def parse_eval_args() -> argparse.Namespace:
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used
parser.add_argument(
"--check_integrity",
action="store_true",
......@@ -226,7 +228,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
assert args.output_path, "Specify --output_path"
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
initialize_tasks(args.verbosity)
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
......@@ -281,12 +285,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning(
f"File already exists at {path}. Results will be overwritten."
f"File {output_path_file} already exists. Results will be overwritten."
)
output_path_file = path.joinpath("results.json")
assert not path.is_file(), "File already exists"
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
......@@ -294,7 +299,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
......@@ -321,17 +325,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs,
task_manager=task_manager,
predict_only=args.predict_only,
**request_caching_args,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
**request_caching_args,
)
if results is not None:
......
from dataclasses import dataclass, field
from typing import Literal, Tuple
from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
@dataclass
class Instance:
request_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
request_type: OutputType
doc: dict
arguments: tuple
idx: int
metadata: Tuple[str, int, int] = field(
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
) # TODO: better typehints here
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: str = None
doc_id: str = None
repeats: str = None
task_name: Optional[str] = None
doc_id: Optional[int] = None
repeats: Optional[int] = None
def __post_init__(self) -> None:
# unpack metadata field
......
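For context on the type-hint changes above, an `Instance` is typically constructed with the request data plus a `(task_name, doc_id, repeats)` metadata tuple that `__post_init__` unpacks. A small sketch with a made-up task and document:

```python
from lm_eval.api.instance import Instance

# The metadata tuple is unpacked in __post_init__ into
# task_name, doc_id, and repeats.
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "1 + 1 = ?", "answer": "2"},  # made-up document
    arguments=("1 + 1 = ?", " 2"),                 # (context, continuation)
    idx=0,
    metadata=("example_task", 0, 1),               # hypothetical task name
)
print(inst.task_name, inst.doc_id, inst.repeats)   # example_task 0 1
```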
......@@ -7,7 +7,18 @@ from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Iterator, List, Literal, Tuple, Union
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import datasets
import numpy as np
......@@ -15,12 +26,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
bits_per_byte,
mean,
weighted_perplexity,
)
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -47,56 +54,54 @@ eval_logger = logging.getLogger("lm-eval")
@dataclass
class TaskConfig(dict):
# task naming/registry
task: str = None
task_alias: str = None
group: Union[str, list] = None
group_alias: Union[str, list] = None
task: Optional[str] = None
task_alias: Optional[str] = None
group: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
dataset_path: str = None
dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[
str
] = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
doc_to_choice: Union[Callable, str, dict, list] = None
process_results: Union[Callable, str] = None
use_prompt: str = None
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: int = None
num_fewshot: Optional[int] = None
# scoring options
metric_list: list = None
output_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Union[str, list] = None
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
metadata: dict = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
doc_to_decontamination_query: Optional[str] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
raise ValueError(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
assert self.output_type != "generate_until"
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
......@@ -177,23 +182,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION = None
VERSION: Optional[Union[int, str]] = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: str = None
DATASET_PATH: Optional[str] = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: str = None
DATASET_NAME: Optional[str] = None
OUTPUT_TYPE: str = None
OUTPUT_TYPE: Optional[OutputType] = None
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config=None,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
......@@ -217,15 +222,20 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None
self._fewshot_docs = None
self._instances = None
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._config = TaskConfig({**config}) if config else TaskConfig()
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
......@@ -259,7 +269,7 @@ class Task(abc.ABC):
)
@property
def config(self):
def config(self) -> TaskConfig:
"""Returns the TaskConfig associated with this class."""
return self._config
......@@ -278,28 +288,28 @@ class Task(abc.ABC):
"""Whether the task has a test set"""
pass
def training_docs(self):
def training_docs(self) -> Iterable:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def validation_docs(self):
def validation_docs(self) -> Iterable:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def test_docs(self):
def test_docs(self) -> Iterable:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def fewshot_docs(self):
def fewshot_docs(self) -> Iterable:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
......@@ -315,7 +325,7 @@ class Task(abc.ABC):
)
return self.test_docs()
def _process_doc(self, doc):
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
......@@ -339,11 +349,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc) -> None:
print(
def doc_to_decontamination_query(self, doc):
raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query."
)
assert False
@abc.abstractmethod
def doc_to_text(self, doc):
......@@ -435,7 +444,8 @@ class Task(abc.ABC):
self._instances = flattened_instances
assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances)
......@@ -528,9 +538,10 @@ class Task(abc.ABC):
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else ""
......@@ -566,7 +577,7 @@ class Task(abc.ABC):
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self):
def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
......@@ -628,7 +639,9 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
raise ValueError(
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
)
def doc_iterator(
self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
......@@ -649,7 +662,11 @@ class ConfigurableTask(Task):
CONFIG = None
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -672,7 +689,10 @@ class ConfigurableTask(Task):
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
assert self.config.output_type in ALL_OUTPUT_TYPES
if self.config.output_type not in ALL_OUTPUT_TYPES:
raise ValueError(
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
)
self.OUTPUT_TYPE = self.config.output_type
if self.config.dataset_path is not None:
......@@ -699,7 +719,10 @@ class ConfigurableTask(Task):
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
assert "metric" in metric_config
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
......@@ -844,7 +867,7 @@ class ConfigurableTask(Task):
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs=None) -> None:
def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -906,7 +929,7 @@ class ConfigurableTask(Task):
return super().fewshot_docs()
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot):
def fewshot_context(self, doc: str, num_fewshot: int) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -970,7 +993,7 @@ class ConfigurableTask(Task):
)
)
def _process_doc(self, doc):
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
......@@ -1015,7 +1038,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str, list]:
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
else:
......@@ -1190,7 +1213,8 @@ class ConfigurableTask(Task):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
......@@ -1352,7 +1376,7 @@ class ConfigurableTask(Task):
class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood"
OUTPUT_TYPE = "loglikelihood"
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
......@@ -1370,7 +1394,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict:
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
......@@ -1405,13 +1429,17 @@ class PerplexityTask(Task):
return False
def fewshot_examples(self, k: int, rnd) -> List:
assert k == 0
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return []
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
if num_fewshot != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return ""
......@@ -1431,8 +1459,9 @@ class PerplexityTask(Task):
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs):
assert not ctx
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
if bool(ctx):
raise ValueError
return Instance(
request_type=self.OUTPUT_TYPE,
......@@ -1442,7 +1471,7 @@ class PerplexityTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: float) -> dict:
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
......
import collections
import itertools
import logging
import random
from typing import TYPE_CHECKING, Optional, Union
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
......@@ -10,6 +10,7 @@ import torch
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
......@@ -20,25 +21,19 @@ from lm_eval.evaluator_utils import (
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
eval_logger,
positional_deprecated,
simple_parse_args_string,
)
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.api.model import LM
from lm_eval.tasks import Task
from lm_eval.caching.cache import delete_cache
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict, None]] = None,
tasks=None,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
......@@ -50,11 +45,10 @@ def simple_evaluate(
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: str = None,
task_manager: TaskManager = None,
gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None,
verbosity: str = "INFO",
predict_only: bool = False,
random_seed: int = 0,
......@@ -136,14 +130,16 @@ def simple_evaluate(
if tasks is None:
tasks = []
assert (
tasks != []
), "No tasks specified, or no tasks found. Please verify the task names."
if len(tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
gen_kwargs = None
......@@ -172,7 +168,8 @@ def simple_evaluate(
},
)
else:
assert isinstance(model, lm_eval.api.model.LM)
if not isinstance(model, lm_eval.api.model.LM):
raise TypeError
lm = model
if use_cache is not None:
......@@ -237,7 +234,6 @@ def simple_evaluate(
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out,
log_samples=log_samples,
verbosity=verbosity,
......@@ -272,18 +268,14 @@ def simple_evaluate(
return None
decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
cache_requests=False,
rewrite_requests_cache=False,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000,
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
verbosity: str = "INFO",
......@@ -307,21 +299,21 @@ def evaluate(
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
# decontaminate = decontamination_ngrams_path is not None
# tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list)
requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int)
padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
if not log_samples:
assert all(
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks
), "log_samples must be True for 'bypass' only tasks"
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit)
......@@ -394,7 +386,7 @@ def evaluate(
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
# Pre-process task.instances to group by doc_id
instances_by_doc_id = collections.defaultdict(list)
instances_by_doc_id = defaultdict(list)
for instance in task.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
......@@ -521,10 +513,9 @@ def evaluate(
results[group]["samples"] = sum(sizes)
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
left_tasks_list = []
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
......@@ -548,7 +539,7 @@ def evaluate(
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())},
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
......@@ -564,11 +555,9 @@ def evaluate(
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": (
True if cache_requests == "true" or cache_requests == "refresh" else False
),
"rewrite_requests_cache": True if cache_requests == "refresh" else False,
"delete_requests_cache": True if cache_requests == "delete" else False,
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
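A quick usage sketch of the simplified helper above, assuming it is imported from `lm_eval.evaluator` as defined in this diff:

```python
from lm_eval.evaluator import request_caching_arg_to_dict

# "refresh" enables caching and forces the cached requests to be rebuilt;
# "delete" only clears the cache; "true" reuses an existing cache.
print(request_caching_arg_to_dict(cache_requests="refresh"))
# {'cache_requests': True, 'rewrite_requests_cache': True, 'delete_requests_cache': False}
```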
import abc
import collections
import logging
import os
from functools import partial
from typing import Dict, List, Union
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, Task
......@@ -15,7 +14,7 @@ class TaskManager:
"""
def __init__(self, verbosity="INFO", include_path=None) -> None:
def __init__(self, verbosity="INFO", include_path: Optional[str] = None) -> None:
self.verbosity = verbosity
self.include_path = include_path
self.logger = utils.eval_logger
......@@ -26,8 +25,8 @@ class TaskManager:
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(self, include_path: str = None):
"""Creates an dictionary of tasks index.
def initialize_tasks(self, include_path: Optional[str] = None):
"""Creates a dictionary of tasks index.
:param include_path: str = None
An additional path to be searched for tasks
......@@ -59,7 +58,7 @@ class TaskManager:
def match_tasks(self, task_list):
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name):
def _name_is_registered(self, name) -> bool:
if name in self.all_tasks:
return True
return False
......@@ -69,7 +68,7 @@ class TaskManager:
return True
return False
def _name_is_group(self, name):
def _name_is_group(self, name) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
......@@ -83,27 +82,29 @@ class TaskManager:
return True
return False
def _config_is_task(self, config):
def _config_is_task(self, config) -> bool:
if ("task" in config) and isinstance(config["task"], str):
return True
return False
def _config_is_group(self, config):
def _config_is_group(self, config) -> bool:
if ("task" in config) and isinstance(config["task"], list):
return True
return False
def _config_is_python_task(self, config):
def _config_is_python_task(self, config) -> bool:
if "class" in config:
return True
return False
def _get_yaml_path(self, name):
assert name in self.task_index
if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"]
def _get_config(self, name):
assert name in self.task_index
if name not in self.task_index:
raise ValueError
yaml_path = self._get_yaml_path(name)
if yaml_path == -1:
return {}
......@@ -111,7 +112,8 @@ class TaskManager:
return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name):
assert self._name_is_task(name) is False
if self._name_is_task(name):
raise ValueError
return self.task_index[name]["task"]
def _process_alias(self, config, group=None):
......@@ -125,14 +127,15 @@ class TaskManager:
def _load_individual_task_or_group(
self,
name_or_config: Union[str, dict] = None,
parent_name: str = None,
update_config: dict = None,
yaml_path: str = None,
) -> ConfigurableTask:
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping:
def load_task(config, task, group=None, yaml_path=None):
if "include" in config:
assert yaml_path is not None
if yaml_path is None:
raise ValueError
config.update(
utils.load_yaml_config(
yaml_path,
......@@ -166,7 +169,7 @@ class TaskManager:
# This checks if we're at the root.
if parent_name is None:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]):
if set(group_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in group_config.items()
......@@ -228,7 +231,7 @@ class TaskManager:
else:
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
if set(name_or_config.keys()) > set(["task", "group"]):
if set(name_or_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in name_or_config.items()
......@@ -251,7 +254,7 @@ class TaskManager:
}
return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
......@@ -272,7 +275,7 @@ class TaskManager:
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str):
"""Creates an dictionary of tasks index with the following metadata,
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, or `group`.
`task` refers to regular task configs, `python_task` are special
yaml files that only consist of `task` and `class` parameters.
......@@ -358,7 +361,8 @@ def include_path(task_dir):
logger.setLevel(getattr(logging, "INFO"))
logger.info(
"To still use tasks loaded from args.include_path,"
"see an example of the new TaskManager API in https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
"see an example of the new TaskManager API in "
"https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
)
return 0
......@@ -397,7 +401,8 @@ def get_task_name_from_object(task_object):
def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
task_name_list: List[Union[str, Dict, Task]],
task_manager: Optional[TaskManager] = None,
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
......@@ -442,9 +447,10 @@ def get_task_dict(
get_task_name_from_object(task_element): task_element,
}
assert set(task_name_from_string_dict.keys()).isdisjoint(
if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys())
)
):
raise ValueError
return {
**task_name_from_string_dict,
......
......@@ -95,9 +95,9 @@ all = [
[tool.ruff.lint]
extend-select = ["I"]
[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["lm_eval"]
[tool.ruff.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"]
......@@ -3,7 +3,7 @@ import random
import transformers
from lm_eval import evaluator, tasks
from lm_eval.base import LM
from lm_eval.api.model import LM
class DryrunLM(LM):
......@@ -53,13 +53,12 @@ def main():
values = []
for taskname in task_list.split(","):
lm.tokencost = 0
evaluator.evaluate(
evaluator.simple_evaluate(
lm=lm,
task_dict={taskname: tasks.get_task(taskname)()},
num_fewshot=0,
limit=None,
bootstrap_iters=10,
description_dict=None,
)
print(taskname, lm.tokencost)
......