Unverified Commit 4bb77e82 authored by Hailey Schoelkopf, committed by GitHub

add include_defaults kwarg to taskmanager, add tests for include_path (#1856)

parent d0f6e011
@@ -14,27 +14,43 @@ class TaskManager:
     """
-    def __init__(self, verbosity="INFO", include_path: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        verbosity="INFO",
+        include_path: Optional[Union[str, List]] = None,
+        include_defaults: bool = True,
+    ) -> None:
         self.verbosity = verbosity
         self.include_path = include_path
         self.logger = utils.eval_logger
         self.logger.setLevel(getattr(logging, f"{verbosity}"))
-        self._task_index = self.initialize_tasks(include_path=include_path)
+        self._task_index = self.initialize_tasks(
+            include_path=include_path, include_defaults=include_defaults
+        )
         self._all_tasks = sorted(list(self._task_index.keys()))
         self.task_group_map = collections.defaultdict(list)

-    def initialize_tasks(self, include_path: Optional[str] = None):
+    def initialize_tasks(
+        self,
+        include_path: Optional[Union[str, List]] = None,
+        include_defaults: bool = True,
+    ):
         """Creates a dictionary of tasks index.
-        :param include_path: str = None
-            An additional path to be searched for tasks
+        :param include_path: Union[str, List] = None
+            An additional path to be searched for tasks recursively.
+            Can provide more than one such path as a list.
+        :param include_defaults: bool = True
+            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
         :return
             Dictionary of task names as key and task metadata
         """
-        all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
+        if include_defaults:
+            all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
+        else:
+            all_paths = []
         if include_path is not None:
             if isinstance(include_path, str):
                 include_path = [include_path]
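Outside the diff, a minimal usage sketch of the new keyword arguments (directory names below are placeholders): more than one include directory can be passed as a list, and include_defaults=False restricts the index to those directories only. It assumes the existing all_tasks property, which exposes the sorted index built in __init__.

from lm_eval import tasks

# placeholder paths; any directories containing task YAML configs will do
task_manager = tasks.TaskManager(
    include_path=["/path/to/my_configs", "/path/to/more_configs"],
    include_defaults=False,  # skip the built-in configs under lm_eval/tasks/
)
# only tasks discovered under the included paths are indexed
print(task_manager.all_tasks)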
......
import os

import pytest

import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks


@pytest.mark.parametrize(
    "limit,model,model_args",
    [
        (
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        ),
    ],
)
def test_include_correctness(limit: int, model: str, model_args: str):
    task_name = ["arc_easy"]
    task_manager = tasks.TaskManager()
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    assert e1 is not None

    # run with evaluate() and "arc_easy" test config (included from ./testconfigs path)
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )

    task_name = ["arc_easy"]
    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
    )
    assert e2 is not None

    # results from the default arc_easy config and the included copy should match
    def r(x):
        return x["results"]["arc_easy"]

    assert all(
        x == y
        for x, y in zip([y for _, y in r(e1).items()], [y for _, y in r(e2).items()])
    )


# test that setting include_defaults = False works as expected and that include_path works
def test_no_include_defaults():
    task_name = ["arc_easy"]

    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )

    # should succeed, because we've included an 'arc_easy' task from this dir
    task_dict = tasks.get_task_dict(task_name, task_manager)

    # should fail, since ./testconfigs has no arc_challenge task
    task_name = ["arc_challenge"]
    with pytest.raises(KeyError):
        task_dict = tasks.get_task_dict(task_name, task_manager)  # noqa: F841


# test that include_path containing a task shadowing another task's name fails
# def test_shadowed_name_fails():
#     task_name = ["arc_easy"]
#     task_manager = tasks.TaskManager(include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs")
#     task_dict = tasks.get_task_dict(task_name, task_manager)
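# arc_easy config shipped under the tests' testconfigs/ directory; it mirrors the
# default arc_easy task so the two evaluation paths above can be compared.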
task: arc_easy
dataset_path: allenai/ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{choices.label.index(answerKey)}}"
doc_to_choice: "{{choices.text}}"
should_decontaminate: true
doc_to_decontamination_query: "Question: {{question}}\nAnswer:"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
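For reference, a sketch of driving this included config end to end through simple_evaluate rather than building the task dict by hand. It assumes the YAML above lives in a testconfigs/ directory next to the calling module, and that simple_evaluate in this version of the harness accepts a task_manager argument.

import os

import lm_eval.evaluator as evaluator
from lm_eval import tasks

testconfigs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testconfigs")
task_manager = tasks.TaskManager(include_path=testconfigs, include_defaults=False)

results = evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
    tasks=["arc_easy"],  # resolved from the included config, not the default one
    limit=10,
    task_manager=task_manager,  # assumed kwarg in this version of the harness
)
print(results["results"]["arc_easy"])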