Unverified Commit 4ee1b386 authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Cleanup and fixes (Task, Instance, and a little bit of *evaluate) (#1533)



* Remove unused `decontamination_ngrams_path` and all mentions (still no alternative path provided)

* Fix improper import of LM and usage of evaluator in one of scripts

* update type hints in instance and task api

* raising errors in task.py instead of asserts

* Fix warnings from ruff

* raising errors in __main__.py instead of asserts

* raising errors in tasks/__init__.py instead of asserts

* raising errors in evaluator.py instead of asserts

* evaluator: update type hints and remove unused variables in code

* Update lm_eval/__main__.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/__main__.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/api/task.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/evaluator.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* pre-commit induced fixes

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 02705057
...@@ -2,15 +2,14 @@ ...@@ -2,15 +2,14 @@
## Usage ## Usage
Simply add a "--decontamination_ngrams_path" when running \__main\__.py. The provided directory should contain The provided directory should contain
the ngram files and info.json produced in "Pile Ngram Generation" further down. the ngram files and info.json produced in "Pile Ngram Generation" further down.
```bash ```bash
python -m lm_eval \ python -m lm_eval \
--model gpt2 \ --model gpt2 \
--device 0 \ --device 0 \
--tasks sciq \ --tasks sciq
--decontamination_ngrams_path path/containing/training/set/ngrams
``` ```
## Background ## Background
...@@ -70,5 +69,3 @@ python -m scripts/clean_training_data/compress_and_package \ ...@@ -70,5 +69,3 @@ python -m scripts/clean_training_data/compress_and_package \
-output path/to/final/directory \ -output path/to/final/directory \
-procs 8 -procs 8
``` ```
Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument.
...@@ -36,8 +36,6 @@ This mode supports a number of command-line arguments, the details of which can ...@@ -36,8 +36,6 @@ This mode supports a number of command-line arguments, the details of which can
- `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`. - `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
- `--decontamination_ngrams_path` : Deprecated, see (this commit)[https://github.com/EleutherAI/lm-evaluation-harness/commit/00209e10f6e27edf5d766145afaf894079b5fe10] or older for a working decontamination-checker tool.
- `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity. - `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
- `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task. - `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
......
...@@ -17,6 +17,9 @@ from lm_eval.tasks import TaskManager, include_path, initialize_tasks ...@@ -17,6 +17,9 @@ from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table, simple_parse_args_string from lm_eval.utils import make_table, simple_parse_args_string
DEFAULT_RESULTS_FILE = "results.json"
def _handle_non_serializable(o): def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32): if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o) return int(o)
...@@ -127,7 +130,6 @@ def parse_eval_args() -> argparse.Namespace: ...@@ -127,7 +130,6 @@ def parse_eval_args() -> argparse.Namespace:
choices=["true", "refresh", "delete"], choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.", help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
) )
parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used
parser.add_argument( parser.add_argument(
"--check_integrity", "--check_integrity",
action="store_true", action="store_true",
...@@ -226,7 +228,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -226,7 +228,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.predict_only: if args.predict_only:
args.log_samples = True args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path: if (args.log_samples or args.predict_only) and not args.output_path:
assert args.output_path, "Specify --output_path" raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
initialize_tasks(args.verbosity) initialize_tasks(args.verbosity)
task_manager = TaskManager(args.verbosity, include_path=args.include_path) task_manager = TaskManager(args.verbosity, include_path=args.include_path)
...@@ -281,12 +285,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -281,12 +285,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.output_path: if args.output_path:
path = Path(args.output_path) path = Path(args.output_path)
# check if file or 'dir/results.json' exists # check if file or 'dir/results.json' exists
if path.is_file() or Path(args.output_path).joinpath("results.json").is_file(): if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning( eval_logger.warning(
f"File already exists at {path}. Results will be overwritten." f"File {output_path_file} already exists. Results will be overwritten."
) )
output_path_file = path.joinpath("results.json")
assert not path.is_file(), "File already exists"
# if path json then get parent dir # if path json then get parent dir
elif path.suffix in (".json", ".jsonl"): elif path.suffix in (".json", ".jsonl"):
output_path_file = path output_path_file = path
...@@ -294,7 +299,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -294,7 +299,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
path = path.parent path = path.parent
else: else:
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code: if args.trust_remote_code:
...@@ -321,17 +325,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -321,17 +325,16 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
device=args.device, device=args.device,
use_cache=args.use_cache, use_cache=args.use_cache,
limit=args.limit, limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity, check_integrity=args.check_integrity,
write_out=args.write_out, write_out=args.write_out,
log_samples=args.log_samples, log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs, gen_kwargs=args.gen_kwargs,
task_manager=task_manager, task_manager=task_manager,
predict_only=args.predict_only, predict_only=args.predict_only,
**request_caching_args,
random_seed=args.seed[0], random_seed=args.seed[0],
numpy_random_seed=args.seed[1], numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2], torch_random_seed=args.seed[2],
**request_caching_args,
) )
if results is not None: if results is not None:
......
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Tuple from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
@dataclass @dataclass
class Instance: class Instance:
request_type: Literal[ request_type: OutputType
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
doc: dict doc: dict
arguments: tuple arguments: tuple
idx: int idx: int
metadata: Tuple[str, int, int] = field( metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None) default_factory=lambda: (None, None, None)
) # TODO: better typehints here )
resps: list = field(default_factory=list) resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict) filtered_resps: dict = field(default_factory=dict)
# initialized after init # initialized after init
task_name: str = None task_name: Optional[str] = None
doc_id: str = None doc_id: Optional[int] = None
repeats: str = None repeats: Optional[int] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
# unpack metadata field # unpack metadata field
......
...@@ -7,7 +7,18 @@ from collections.abc import Callable ...@@ -7,7 +7,18 @@ from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Any, Iterator, List, Literal, Tuple, Union from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import datasets import datasets
import numpy as np import numpy as np
...@@ -15,12 +26,8 @@ from tqdm import tqdm ...@@ -15,12 +26,8 @@ from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
from lm_eval.api import samplers from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import ( from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
bits_per_byte,
mean,
weighted_perplexity,
)
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
...@@ -47,56 +54,54 @@ eval_logger = logging.getLogger("lm-eval") ...@@ -47,56 +54,54 @@ eval_logger = logging.getLogger("lm-eval")
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: str = None task: Optional[str] = None
task_alias: str = None task_alias: Optional[str] = None
group: Union[str, list] = None group: Optional[Union[str, list]] = None
group_alias: Union[str, list] = None group_alias: Optional[Union[str, list]] = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
dataset_path: str = None dataset_path: Optional[str] = None
dataset_name: str = None dataset_name: Optional[str] = None
dataset_kwargs: dict = None dataset_kwargs: Optional[dict] = None
training_split: str = None training_split: Optional[str] = None
validation_split: str = None validation_split: Optional[str] = None
test_split: str = None test_split: Optional[str] = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: Optional[
str
] = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Callable = None process_docs: Optional[Callable] = None
doc_to_text: Union[Callable, str] = None doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Optional[Union[Callable, str]] = None
doc_to_choice: Union[Callable, str, dict, list] = None doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Union[Callable, str] = None process_results: Optional[Union[Callable, str]] = None
use_prompt: str = None use_prompt: Optional[str] = None
description: str = "" description: str = ""
target_delimiter: str = " " target_delimiter: str = " "
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None fewshot_config: Optional[dict] = None
# runtime configuration options # runtime configuration options
num_fewshot: int = None num_fewshot: Optional[int] = None
# scoring options # scoring options
metric_list: list = None metric_list: Optional[list] = None
output_type: Literal[ output_type: OutputType = "generate_until"
"loglikelihood", generation_kwargs: Optional[dict] = None
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None
repeats: int = 1 repeats: int = 1
filter_list: Union[str, list] = None filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: Optional[str] = None
metadata: dict = None # by default, not used in the code. allows for users to pass arbitrary info to tasks metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.generation_kwargs is not None: if self.generation_kwargs is not None:
if self.output_type != "generate_until": if self.output_type != "generate_until":
eval_logger.warning( raise ValueError(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!" f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
) )
assert self.output_type != "generate_until"
if "temperature" in self.generation_kwargs: if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float( self.generation_kwargs["temperature"] = float(
...@@ -177,23 +182,23 @@ class Task(abc.ABC): ...@@ -177,23 +182,23 @@ class Task(abc.ABC):
{"question": ..., question, answer) {"question": ..., question, answer)
""" """
VERSION = None VERSION: Optional[Union[int, str]] = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script. # or a path to a custom `datasets` loading script.
DATASET_PATH: str = None DATASET_PATH: Optional[str] = None
# The name of a subset within `DATASET_PATH`. # The name of a subset within `DATASET_PATH`.
DATASET_NAME: str = None DATASET_NAME: Optional[str] = None
OUTPUT_TYPE: str = None OUTPUT_TYPE: Optional[OutputType] = None
def __init__( def __init__(
self, self,
data_dir=None, data_dir: Optional[str] = None,
cache_dir=None, cache_dir: Optional[str] = None,
download_mode=None, download_mode: Optional[datasets.DownloadMode] = None,
config=None, config: Optional[Mapping] = None, # Union[dict, TaskConfig]
) -> None: ) -> None:
""" """
:param data_dir: str :param data_dir: str
...@@ -217,15 +222,20 @@ class Task(abc.ABC): ...@@ -217,15 +222,20 @@ class Task(abc.ABC):
Fresh download and fresh dataset. Fresh download and fresh dataset.
""" """
self.download(data_dir, cache_dir, download_mode) self.download(data_dir, cache_dir, download_mode)
self._training_docs = None self._training_docs: Optional[list] = None
self._fewshot_docs = None self._fewshot_docs: Optional[list] = None
self._instances = None self._instances: Optional[List[Instance]] = None
self._config = TaskConfig({**config}) if config else TaskConfig() self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])] self._filters = [build_filter_ensemble("none", [["take_first", None]])]
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None: def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset. """Downloads and returns the task dataset.
Override this method to download the dataset from a custom API. Override this method to download the dataset from a custom API.
...@@ -259,7 +269,7 @@ class Task(abc.ABC): ...@@ -259,7 +269,7 @@ class Task(abc.ABC):
) )
@property @property
def config(self): def config(self) -> TaskConfig:
"""Returns the TaskConfig associated with this class.""" """Returns the TaskConfig associated with this class."""
return self._config return self._config
...@@ -278,28 +288,28 @@ class Task(abc.ABC): ...@@ -278,28 +288,28 @@ class Task(abc.ABC):
"""Whether the task has a test set""" """Whether the task has a test set"""
pass pass
def training_docs(self): def training_docs(self) -> Iterable:
""" """
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
""" """
return [] return []
def validation_docs(self): def validation_docs(self) -> Iterable:
""" """
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
""" """
return [] return []
def test_docs(self): def test_docs(self) -> Iterable:
""" """
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
""" """
return [] return []
def fewshot_docs(self): def fewshot_docs(self) -> Iterable:
""" """
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
...@@ -315,7 +325,7 @@ class Task(abc.ABC): ...@@ -315,7 +325,7 @@ class Task(abc.ABC):
) )
return self.test_docs() return self.test_docs()
def _process_doc(self, doc): def _process_doc(self, doc: dict) -> dict:
""" """
Override this to process (detokenize, strip, replace, etc.) individual Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split. documents. This can be used in a map over documents of a data split.
...@@ -339,11 +349,10 @@ class Task(abc.ABC): ...@@ -339,11 +349,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k) return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc) -> None: def doc_to_decontamination_query(self, doc):
print( raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query." "Override doc_to_decontamination_query with document specific decontamination query."
) )
assert False
@abc.abstractmethod @abc.abstractmethod
def doc_to_text(self, doc): def doc_to_text(self, doc):
...@@ -435,7 +444,8 @@ class Task(abc.ABC): ...@@ -435,7 +444,8 @@ class Task(abc.ABC):
self._instances = flattened_instances self._instances = flattened_instances
assert len(self._instances) != 0, "task.build_requests() did not find any docs!" if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
if cache_requests and (not cached_instances or rewrite_requests_cache): if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances) save_to_cache(file_name=cache_key, obj=instances)
...@@ -528,9 +538,10 @@ class Task(abc.ABC): ...@@ -528,9 +538,10 @@ class Task(abc.ABC):
:returns: str :returns: str
The fewshot context. The fewshot context.
""" """
assert ( if rnd is None:
rnd is not None raise ValueError(
), "A `random.Random` generator argument must be provided to `rnd`" "A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else "" description = description if description else ""
...@@ -566,7 +577,7 @@ class Task(abc.ABC): ...@@ -566,7 +577,7 @@ class Task(abc.ABC):
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
return description + labeled_examples + example return description + labeled_examples + example
def apply_filters(self): def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances""" """Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"): if hasattr(self, "_filters"):
for f in self._filters: for f in self._filters:
...@@ -628,7 +639,9 @@ class Task(abc.ABC): ...@@ -628,7 +639,9 @@ class Task(abc.ABC):
elif self.has_validation_docs(): elif self.has_validation_docs():
return self.validation_docs() return self.validation_docs()
else: else:
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" raise ValueError(
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
)
def doc_iterator( def doc_iterator(
self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1 self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
...@@ -649,7 +662,11 @@ class ConfigurableTask(Task): ...@@ -649,7 +662,11 @@ class ConfigurableTask(Task):
CONFIG = None CONFIG = None
def __init__( def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here ) -> None: # TODO no super() call here
# Get pre-configured attributes # Get pre-configured attributes
self._config = self.CONFIG self._config = self.CONFIG
...@@ -672,7 +689,10 @@ class ConfigurableTask(Task): ...@@ -672,7 +689,10 @@ class ConfigurableTask(Task):
self.VERSION = self.config.metadata["version"] self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None: if self.config.output_type is not None:
assert self.config.output_type in ALL_OUTPUT_TYPES if self.config.output_type not in ALL_OUTPUT_TYPES:
raise ValueError(
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
)
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
if self.config.dataset_path is not None: if self.config.dataset_path is not None:
...@@ -699,7 +719,10 @@ class ConfigurableTask(Task): ...@@ -699,7 +719,10 @@ class ConfigurableTask(Task):
self._higher_is_better[metric_name] = is_higher_better(metric_name) self._higher_is_better[metric_name] = is_higher_better(metric_name)
else: else:
for metric_config in self.config.metric_list: for metric_config in self.config.metric_list:
assert "metric" in metric_config if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
kwargs = { kwargs = {
key: metric_config[key] key: metric_config[key]
...@@ -844,7 +867,7 @@ class ConfigurableTask(Task): ...@@ -844,7 +867,7 @@ class ConfigurableTask(Task):
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
) )
def download(self, dataset_kwargs=None) -> None: def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.dataset = datasets.load_dataset( self.dataset = datasets.load_dataset(
path=self.DATASET_PATH, path=self.DATASET_PATH,
name=self.DATASET_NAME, name=self.DATASET_NAME,
...@@ -906,7 +929,7 @@ class ConfigurableTask(Task): ...@@ -906,7 +929,7 @@ class ConfigurableTask(Task):
return super().fewshot_docs() return super().fewshot_docs()
@utils.positional_deprecated @utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot): def fewshot_context(self, doc: str, num_fewshot: int) -> str:
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -970,7 +993,7 @@ class ConfigurableTask(Task): ...@@ -970,7 +993,7 @@ class ConfigurableTask(Task):
) )
) )
def _process_doc(self, doc): def _process_doc(self, doc: dict) -> dict:
""" """
Override this to process (detokenize, strip, replace, etc.) individual Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split. documents. This can be used in a map over documents of a data split.
...@@ -1015,7 +1038,7 @@ class ConfigurableTask(Task): ...@@ -1015,7 +1038,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str, list]: def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
else: else:
...@@ -1190,7 +1213,8 @@ class ConfigurableTask(Task): ...@@ -1190,7 +1213,8 @@ class ConfigurableTask(Task):
# then we are doing mutual info. # then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods # this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2] lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices) if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods # and this stores our "regular" conditional loglikelihoods
lls = lls[::2] lls = lls[::2]
...@@ -1352,7 +1376,7 @@ class ConfigurableTask(Task): ...@@ -1352,7 +1376,7 @@ class ConfigurableTask(Task):
class MultipleChoiceTask(Task): class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood" OUTPUT_TYPE = "loglikelihood"
def doc_to_target(self, doc: dict) -> str: def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]] return " " + doc["choices"][doc["gold"]]
...@@ -1370,7 +1394,7 @@ class MultipleChoiceTask(Task): ...@@ -1370,7 +1394,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"]) for i, choice in enumerate(doc["choices"])
] ]
def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict: def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
results = [ results = [
res[0] for res in results res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...@@ -1405,13 +1429,17 @@ class PerplexityTask(Task): ...@@ -1405,13 +1429,17 @@ class PerplexityTask(Task):
return False return False
def fewshot_examples(self, k: int, rnd) -> List: def fewshot_examples(self, k: int, rnd) -> List:
assert k == 0 if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return [] return []
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]: def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
assert ( if num_fewshot != 0:
num_fewshot == 0 raise ValueError(
), "The number of fewshot examples must be 0 for perplexity tasks." "The number of fewshot examples must be 0 for perplexity tasks."
)
return "" return ""
...@@ -1431,8 +1459,9 @@ class PerplexityTask(Task): ...@@ -1431,8 +1459,9 @@ class PerplexityTask(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc return doc
def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs): def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
assert not ctx if bool(ctx):
raise ValueError
return Instance( return Instance(
request_type=self.OUTPUT_TYPE, request_type=self.OUTPUT_TYPE,
...@@ -1442,7 +1471,7 @@ class PerplexityTask(Task): ...@@ -1442,7 +1471,7 @@ class PerplexityTask(Task):
**kwargs, **kwargs,
) )
def process_results(self, doc: dict, results: float) -> dict: def process_results(self, doc: dict, results: Tuple[float]) -> dict:
(loglikelihood,) = results (loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc)) words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc)) bytes_ = self.count_bytes(self.doc_to_target(doc))
......
import collections
import itertools import itertools
import logging import logging
import random import random
from typing import TYPE_CHECKING, Optional, Union from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import lm_eval.api.metrics import lm_eval.api.metrics
import lm_eval.api.registry import lm_eval.api.registry
import lm_eval.models import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import ( from lm_eval.evaluator_utils import (
consolidate_results, consolidate_results,
get_sample_size, get_sample_size,
...@@ -20,25 +21,19 @@ from lm_eval.evaluator_utils import ( ...@@ -20,25 +21,19 @@ from lm_eval.evaluator_utils import (
) )
from lm_eval.logging_utils import add_env_info, get_git_commit_hash from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import ( from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
eval_logger,
positional_deprecated,
simple_parse_args_string,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.tasks import Task from lm_eval.tasks import Task
from lm_eval.caching.cache import delete_cache
@positional_deprecated @positional_deprecated
def simple_evaluate( def simple_evaluate(
model, model,
model_args: Optional[Union[str, dict, None]] = None, model_args: Optional[Union[str, dict]] = None,
tasks=None, tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None, num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None,
...@@ -50,11 +45,10 @@ def simple_evaluate( ...@@ -50,11 +45,10 @@ def simple_evaluate(
limit: Optional[Union[int, float]] = None, limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000, bootstrap_iters: int = 100000,
check_integrity: bool = False, check_integrity: bool = False,
decontamination_ngrams_path=None,
write_out: bool = False, write_out: bool = False,
log_samples: bool = True, log_samples: bool = True,
gen_kwargs: str = None, gen_kwargs: Optional[str] = None,
task_manager: TaskManager = None, task_manager: Optional[TaskManager] = None,
verbosity: str = "INFO", verbosity: str = "INFO",
predict_only: bool = False, predict_only: bool = False,
random_seed: int = 0, random_seed: int = 0,
...@@ -136,14 +130,16 @@ def simple_evaluate( ...@@ -136,14 +130,16 @@ def simple_evaluate(
if tasks is None: if tasks is None:
tasks = [] tasks = []
assert ( if len(tasks) == 0:
tasks != [] raise ValueError(
), "No tasks specified, or no tasks found. Please verify the task names." "No tasks specified, or no tasks found. Please verify the task names."
)
if gen_kwargs is not None: if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs) gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning( eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!" "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
) )
if gen_kwargs == "": if gen_kwargs == "":
gen_kwargs = None gen_kwargs = None
...@@ -172,7 +168,8 @@ def simple_evaluate( ...@@ -172,7 +168,8 @@ def simple_evaluate(
}, },
) )
else: else:
assert isinstance(model, lm_eval.api.model.LM) if not isinstance(model, lm_eval.api.model.LM):
raise TypeError
lm = model lm = model
if use_cache is not None: if use_cache is not None:
...@@ -237,7 +234,6 @@ def simple_evaluate( ...@@ -237,7 +234,6 @@ def simple_evaluate(
cache_requests=cache_requests, cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache, rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters, bootstrap_iters=bootstrap_iters,
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out, write_out=write_out,
log_samples=log_samples, log_samples=log_samples,
verbosity=verbosity, verbosity=verbosity,
...@@ -272,18 +268,14 @@ def simple_evaluate( ...@@ -272,18 +268,14 @@ def simple_evaluate(
return None return None
decontaminate_suffix = "_decontaminate"
@positional_deprecated @positional_deprecated
def evaluate( def evaluate(
lm: "LM", lm: "LM",
task_dict, task_dict,
limit: Optional[int] = None, limit: Optional[int] = None,
cache_requests=False, cache_requests: bool = False,
rewrite_requests_cache=False, rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000, bootstrap_iters: Optional[int] = 100000,
decontamination_ngrams_path=None,
write_out: bool = False, write_out: bool = False,
log_samples: bool = True, log_samples: bool = True,
verbosity: str = "INFO", verbosity: str = "INFO",
...@@ -307,21 +299,21 @@ def evaluate( ...@@ -307,21 +299,21 @@ def evaluate(
""" """
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
# decontaminate = decontamination_ngrams_path is not None
# tracks all Instances/requests a model must generate output on. # tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list) requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that # stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal # number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int) padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request # get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict) task_hierarchy, eval_tasks = get_task_list(task_dict)
if not log_samples: if not log_samples:
assert all( if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks for task_output in eval_tasks
), "log_samples must be True for 'bypass' only tasks" ):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
for task_output in eval_tasks: for task_output in eval_tasks:
task: Task = task_output.task task: Task = task_output.task
limit = get_sample_size(task, limit) limit = get_sample_size(task, limit)
...@@ -394,7 +386,7 @@ def evaluate( ...@@ -394,7 +386,7 @@ def evaluate(
# # unpack results and sort back in order and return control to Task # # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter # TODO: make it possible to use a different metric per filter
# Pre-process task.instances to group by doc_id # Pre-process task.instances to group by doc_id
instances_by_doc_id = collections.defaultdict(list) instances_by_doc_id = defaultdict(list)
for instance in task.instances: for instance in task.instances:
instances_by_doc_id[instance.doc_id].append(instance) instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group # Sort instances within each group
...@@ -521,10 +513,9 @@ def evaluate( ...@@ -521,10 +513,9 @@ def evaluate(
results[group]["samples"] = sum(sizes) results[group]["samples"] = sum(sizes)
results_agg = collections.defaultdict(dict) results_agg = defaultdict(dict)
groups_agg = collections.defaultdict(dict) groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys()) all_tasks_list = list(task_hierarchy.keys())
left_tasks_list = []
while True: while True:
add_tasks_list = list(k for k in results_agg.keys()) add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list))) left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
...@@ -548,7 +539,7 @@ def evaluate( ...@@ -548,7 +539,7 @@ def evaluate(
results_dict = { results_dict = {
"results": dict(results_agg.items()), "results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}), **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())}, "group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())), "configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())), "versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())), "n-shot": dict(sorted(num_fewshot.items())),
...@@ -564,11 +555,9 @@ def evaluate( ...@@ -564,11 +555,9 @@ def evaluate(
def request_caching_arg_to_dict(cache_requests: str) -> dict: def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = { request_caching_args = {
"cache_requests": ( "cache_requests": cache_requests in {"true", "refresh"},
True if cache_requests == "true" or cache_requests == "refresh" else False "rewrite_requests_cache": cache_requests == "refresh",
), "delete_requests_cache": cache_requests == "delete",
"rewrite_requests_cache": True if cache_requests == "refresh" else False,
"delete_requests_cache": True if cache_requests == "delete" else False,
} }
return request_caching_args return request_caching_args
import abc
import collections import collections
import logging import logging
import os import os
from functools import partial from functools import partial
from typing import Dict, List, Union from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
...@@ -15,7 +14,7 @@ class TaskManager: ...@@ -15,7 +14,7 @@ class TaskManager:
""" """
def __init__(self, verbosity="INFO", include_path=None) -> None: def __init__(self, verbosity="INFO", include_path: Optional[str] = None) -> None:
self.verbosity = verbosity self.verbosity = verbosity
self.include_path = include_path self.include_path = include_path
self.logger = utils.eval_logger self.logger = utils.eval_logger
...@@ -26,8 +25,8 @@ class TaskManager: ...@@ -26,8 +25,8 @@ class TaskManager:
self.task_group_map = collections.defaultdict(list) self.task_group_map = collections.defaultdict(list)
def initialize_tasks(self, include_path: str = None): def initialize_tasks(self, include_path: Optional[str] = None):
"""Creates an dictionary of tasks index. """Creates a dictionary of tasks index.
:param include_path: str = None :param include_path: str = None
An additional path to be searched for tasks An additional path to be searched for tasks
...@@ -59,7 +58,7 @@ class TaskManager: ...@@ -59,7 +58,7 @@ class TaskManager:
def match_tasks(self, task_list): def match_tasks(self, task_list):
return utils.pattern_match(task_list, self.all_tasks) return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name): def _name_is_registered(self, name) -> bool:
if name in self.all_tasks: if name in self.all_tasks:
return True return True
return False return False
...@@ -69,7 +68,7 @@ class TaskManager: ...@@ -69,7 +68,7 @@ class TaskManager:
return True return True
return False return False
def _name_is_group(self, name): def _name_is_group(self, name) -> bool:
if self._name_is_registered(name) and ( if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group" self.task_index[name]["type"] == "group"
): ):
...@@ -83,27 +82,29 @@ class TaskManager: ...@@ -83,27 +82,29 @@ class TaskManager:
return True return True
return False return False
def _config_is_task(self, config): def _config_is_task(self, config) -> bool:
if ("task" in config) and isinstance(config["task"], str): if ("task" in config) and isinstance(config["task"], str):
return True return True
return False return False
def _config_is_group(self, config): def _config_is_group(self, config) -> bool:
if ("task" in config) and isinstance(config["task"], list): if ("task" in config) and isinstance(config["task"], list):
return True return True
return False return False
def _config_is_python_task(self, config): def _config_is_python_task(self, config) -> bool:
if "class" in config: if "class" in config:
return True return True
return False return False
def _get_yaml_path(self, name): def _get_yaml_path(self, name):
assert name in self.task_index if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
def _get_config(self, name): def _get_config(self, name):
assert name in self.task_index if name not in self.task_index:
raise ValueError
yaml_path = self._get_yaml_path(name) yaml_path = self._get_yaml_path(name)
if yaml_path == -1: if yaml_path == -1:
return {} return {}
...@@ -111,7 +112,8 @@ class TaskManager: ...@@ -111,7 +112,8 @@ class TaskManager:
return utils.load_yaml_config(yaml_path, mode="full") return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name): def _get_tasklist(self, name):
assert self._name_is_task(name) is False if self._name_is_task(name):
raise ValueError
return self.task_index[name]["task"] return self.task_index[name]["task"]
def _process_alias(self, config, group=None): def _process_alias(self, config, group=None):
...@@ -125,14 +127,15 @@ class TaskManager: ...@@ -125,14 +127,15 @@ class TaskManager:
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Union[str, dict] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: str = None, parent_name: Optional[str] = None,
update_config: dict = None, update_config: Optional[dict] = None,
yaml_path: str = None, yaml_path: Optional[str] = None,
) -> ConfigurableTask: ) -> Mapping:
def load_task(config, task, group=None, yaml_path=None): def load_task(config, task, group=None, yaml_path=None):
if "include" in config: if "include" in config:
assert yaml_path is not None if yaml_path is None:
raise ValueError
config.update( config.update(
utils.load_yaml_config( utils.load_yaml_config(
yaml_path, yaml_path,
...@@ -166,7 +169,7 @@ class TaskManager: ...@@ -166,7 +169,7 @@ class TaskManager:
# This checks if we're at the root. # This checks if we're at the root.
if parent_name is None: if parent_name is None:
group_config = self._get_config(name_or_config) group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]): if set(group_config.keys()) > {"task", "group"}:
update_config = { update_config = {
k: v k: v
for k, v in group_config.items() for k, v in group_config.items()
...@@ -228,7 +231,7 @@ class TaskManager: ...@@ -228,7 +231,7 @@ class TaskManager:
else: else:
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
if set(name_or_config.keys()) > set(["task", "group"]): if set(name_or_config.keys()) > {"task", "group"}:
update_config = { update_config = {
k: v k: v
for k, v in name_or_config.items() for k, v in name_or_config.items()
...@@ -251,7 +254,7 @@ class TaskManager: ...@@ -251,7 +254,7 @@ class TaskManager:
} }
return all_subtasks return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None :param task_list: Union[str, list] = None
...@@ -272,7 +275,7 @@ class TaskManager: ...@@ -272,7 +275,7 @@ class TaskManager:
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: str): def _get_task_and_group(self, task_dir: str):
"""Creates an dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, or `group`. - `type`, that can be either `task`, `python_task`, or `group`.
`task` refer to regular task configs, `python_task` are special `task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters. yaml files that only consists of `task` and `class` parameters.
...@@ -358,7 +361,8 @@ def include_path(task_dir): ...@@ -358,7 +361,8 @@ def include_path(task_dir):
logger.setLevel(getattr(logging, "INFO")) logger.setLevel(getattr(logging, "INFO"))
logger.info( logger.info(
"To still use tasks loaded from args.include_path," "To still use tasks loaded from args.include_path,"
"see an example of the new TaskManager API in https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage" "see an example of the new TaskManager API in "
"https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
) )
return 0 return 0
...@@ -397,7 +401,8 @@ def get_task_name_from_object(task_object): ...@@ -397,7 +401,8 @@ def get_task_name_from_object(task_object):
def get_task_dict( def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None task_name_list: List[Union[str, Dict, Task]],
task_manager: Optional[TaskManager] = None,
): ):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
...@@ -442,9 +447,10 @@ def get_task_dict( ...@@ -442,9 +447,10 @@ def get_task_dict(
get_task_name_from_object(task_element): task_element, get_task_name_from_object(task_element): task_element,
} }
assert set(task_name_from_string_dict.keys()).isdisjoint( if not set(task_name_from_string_dict.keys()).isdisjoint(
set(task_name_from_object_dict.keys()) set(task_name_from_object_dict.keys())
) ):
raise ValueError
return { return {
**task_name_from_string_dict, **task_name_from_string_dict,
......
...@@ -95,9 +95,9 @@ all = [ ...@@ -95,9 +95,9 @@ all = [
[tool.ruff.lint] [tool.ruff.lint]
extend-select = ["I"] extend-select = ["I"]
[tool.ruff.isort] [tool.ruff.lint.isort]
lines-after-imports = 2 lines-after-imports = 2
known-first-party = ["lm_eval"] known-first-party = ["lm_eval"]
[tool.ruff.extend-per-file-ignores] [tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"] "__init__.py" = ["F401","F402","F403"]
...@@ -3,7 +3,7 @@ import random ...@@ -3,7 +3,7 @@ import random
import transformers import transformers
from lm_eval import evaluator, tasks from lm_eval import evaluator, tasks
from lm_eval.base import LM from lm_eval.api.model import LM
class DryrunLM(LM): class DryrunLM(LM):
...@@ -53,13 +53,12 @@ def main(): ...@@ -53,13 +53,12 @@ def main():
values = [] values = []
for taskname in task_list.split(","): for taskname in task_list.split(","):
lm.tokencost = 0 lm.tokencost = 0
evaluator.evaluate( evaluator.simple_evaluate(
lm=lm, lm=lm,
task_dict={taskname: tasks.get_task(taskname)()}, task_dict={taskname: tasks.get_task(taskname)()},
num_fewshot=0, num_fewshot=0,
limit=None, limit=None,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None,
) )
print(taskname, lm.tokencost) print(taskname, lm.tokencost)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment