Commit f66fc06f authored by haileyschoelkopf

fix merge conflicts

parents b13753cd d714fc95
@@ -7,6 +7,9 @@ import lm_eval.tasks as tasks
 from lm_eval.api.instance import Instance
+task_manager = tasks.TaskManager()
 @pytest.mark.skip(reason="requires CUDA")
 class TEST_VLLM:
     vllm = pytest.importorskip("vllm")
@@ -17,15 +20,15 @@ class TEST_VLLM:
     except ModuleNotFoundError:
         pass
     torch.use_deterministic_algorithms(True)
-    tasks.initialize_tasks()
-    multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
+    task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k", "wikitext"])
+    multiple_choice_task = task_list["arc_easy"]  # type: ignore
     multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
     MULTIPLE_CH: List[Instance] = multiple_choice_task.instances
-    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")()  # type: ignore
+    generate_until_task = task_list["gsm8k"]  # type: ignore
     generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
     generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
     generate_until: List[Instance] = generate_until_task.instances
-    rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
+    rolling_task = task_list["wikitext"]  # type: ignore
     rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
     ROLLING: List[Instance] = rolling_task.instances
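The hunk above resolves the conflict by dropping the old tasks.initialize_tasks() / TASK_REGISTRY pattern in favor of a module-level TaskManager. A minimal sketch of the new access pattern, using only the calls visible in the hunk (TaskManager, load_task_or_group, build_all_requests, instances); the limit/rank/world_size values simply mirror the test:

    import lm_eval.tasks as tasks

    # One TaskManager indexes the available task configs.
    task_manager = tasks.TaskManager()

    # load_task_or_group returns a dict keyed by task name, so several tasks
    # can be fetched in one call instead of separate TASK_REGISTRY lookups.
    task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k"])

    arc = task_list["arc_easy"]
    arc.build_all_requests(limit=10, rank=0, world_size=1)
    print(len(arc.instances))  # Instance objects the model backend will consume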
@@ -6,11 +6,9 @@ import pytest
 # import lm_eval.models as models
 import lm_eval.api as api
 import lm_eval.evaluator as evaluator
-import lm_eval.tasks as tasks
+from lm_eval import tasks
-tasks.initialize_tasks()
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces
@@ -46,7 +44,8 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str
             "device": None,
         },
     )
-    task_dict = tasks.get_task_dict(task_name, num_fewshot=0)
+    task_manager = tasks.TaskManager()
+    task_dict = tasks.get_task_dict(task_name, task_manager)
     e2 = evaluator.evaluate(
         lm=lm,
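The evaluator test migrates the same way: get_task_dict now receives the TaskManager instance rather than a num_fewshot argument. A hedged sketch of how the pieces connect; the model registry call (get_model / create_from_arg_string) and the checkpoint name are assumptions for illustration, while tasks.TaskManager, tasks.get_task_dict, and evaluator.evaluate(lm=...) come from the diff itself:

    import lm_eval.api as api
    import lm_eval.evaluator as evaluator
    from lm_eval import tasks

    # Hypothetical model construction; the real test builds `lm` from its
    # parametrized `model` / `model_args` arguments.
    lm = api.registry.get_model("hf").create_from_arg_string(
        "pretrained=EleutherAI/pythia-70m",
        {"device": None},
    )

    task_manager = tasks.TaskManager()
    task_dict = tasks.get_task_dict(["arc_easy"], task_manager)

    results = evaluator.evaluate(lm=lm, task_dict=task_dict, limit=10)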
@@ -10,7 +10,7 @@ from .utils import new_tasks
 datasets.disable_caching()
-tasks.initialize_tasks()
+task_manager = tasks.TaskManager()
 # Default Task
 TASKS = ["arc_easy"]
@@ -21,9 +21,9 @@ def task_class():
     task_classes = new_tasks()
     # Check if task_classes is empty
     if task_classes:
-        return [tasks.TASK_REGISTRY.get(x)() for x in task_classes]
+        return list(task_manager.load_task_or_group(task_classes).values())
     else:
-        return [tasks.TASK_REGISTRY.get(x)() for x in TASKS]
+        return list(task_manager.load_task_or_group(TASKS).values())
 @pytest.fixture()
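Since load_task_or_group returns a dict mapping each requested name to its instantiated task object, the fixture above now takes .values() on that dict instead of building a list from individual TASK_REGISTRY lookups.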
 import os
 from pathlib import Path
 from typing import List, Union
 from lm_eval.utils import load_yaml_config
@@ -20,16 +19,18 @@ def load_changed_files(file_path: str) -> List[str]:
 # checks the txt file for list of changed files.
-# if file ends with .yaml then check yaml for task name
-# if file ends with .py then parse the folder for all yaml files
+# if file ends with .yaml then check yaml and load the config.
+# if the config task is a string, it's a task config.
+# if the config task is a list, it's a group config.
 def parser(full_path: List[str]) -> List[str]:
     _output = set()
     for x in full_path:
-        if x.endswith(".yaml"):
-            _output.add(load_yaml_config(x)["task"])
-        elif x.endswith(".py"):
-            path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))]
-            _output |= {load_yaml_config(x)["task"] for x in path}
+        if os.path.exists(x) and x.endswith(".yaml"):
+            config = load_yaml_config(x, mode="simple")
+            if isinstance(config["task"], str):
+                _output.add(config["task"])
+            elif isinstance(config["task"], list):
+                _output.add(config["group"])
     return list(_output)
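The rewritten parser only considers YAML files that still exist on disk, loads each config in "simple" mode, and records either the task name (when "task" is a string) or the enclosing group name (when "task" is a list). A small illustrative sketch of that branch logic; the file paths below are hypothetical, while load_yaml_config(..., mode="simple") and the "task"/"group" keys come from the hunk above:

    import os
    from lm_eval.utils import load_yaml_config

    # Hypothetical changed-file paths; real values come from the CI changed-files list.
    for path in ["lm_eval/tasks/arc/arc_easy.yaml", "lm_eval/tasks/my_suite/group.yaml"]:
        if os.path.exists(path) and path.endswith(".yaml"):
            config = load_yaml_config(path, mode="simple")
            if isinstance(config["task"], str):
                print("task config:", config["task"])    # e.g. "arc_easy"
            elif isinstance(config["task"], list):
                print("group config:", config["group"])  # e.g. "my_suite"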