Commit f66fc06f authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fix merge conflicts

parents b13753cd d714fc95
...@@ -7,6 +7,9 @@ import lm_eval.tasks as tasks ...@@ -7,6 +7,9 @@ import lm_eval.tasks as tasks
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
task_manager = tasks.TaskManager()
@pytest.mark.skip(reason="requires CUDA") @pytest.mark.skip(reason="requires CUDA")
class TEST_VLLM: class TEST_VLLM:
vllm = pytest.importorskip("vllm") vllm = pytest.importorskip("vllm")
...@@ -17,15 +20,15 @@ class TEST_VLLM: ...@@ -17,15 +20,15 @@ class TEST_VLLM:
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
tasks.initialize_tasks() task_list = task_manager.load_task_or_group(["arc_easy", "gsm8k", "wikitext"])
multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")() # type: ignore multiple_choice_task = task_list["arc_easy"] # type: ignore
multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1) multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
MULTIPLE_CH: List[Instance] = multiple_choice_task.instances MULTIPLE_CH: List[Instance] = multiple_choice_task.instances
generate_until_task = tasks.TASK_REGISTRY.get("gsm8k")() # type: ignore generate_until_task = task_list["gsm8k"] # type: ignore
generate_until_task.build_all_requests(limit=10, rank=0, world_size=1) generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
generate_until_task._config.generation_kwargs["max_gen_toks"] = 10 generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
generate_until: List[Instance] = generate_until_task.instances generate_until: List[Instance] = generate_until_task.instances
rolling_task = tasks.TASK_REGISTRY.get("wikitext")() # type: ignore rolling_task = task_list["wikitext"] # type: ignore
rolling_task.build_all_requests(limit=10, rank=0, world_size=1) rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
ROLLING: List[Instance] = rolling_task.instances ROLLING: List[Instance] = rolling_task.instances
......
...@@ -6,11 +6,9 @@ import pytest ...@@ -6,11 +6,9 @@ import pytest
# import lm_eval.models as models # import lm_eval.models as models
import lm_eval.api as api import lm_eval.api as api
import lm_eval.evaluator as evaluator import lm_eval.evaluator as evaluator
import lm_eval.tasks as tasks from lm_eval import tasks
tasks.initialize_tasks()
# TODO: more fine grained unit tests rather than this big honking integration # TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces # test once we break evaluator into smaller, more manageable pieces
...@@ -46,7 +44,8 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str ...@@ -46,7 +44,8 @@ def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str
"device": None, "device": None,
}, },
) )
task_dict = tasks.get_task_dict(task_name, num_fewshot=0) task_manager = tasks.TaskManager()
task_dict = tasks.get_task_dict(task_name, task_manager)
e2 = evaluator.evaluate( e2 = evaluator.evaluate(
lm=lm, lm=lm,
......
...@@ -10,7 +10,7 @@ from .utils import new_tasks ...@@ -10,7 +10,7 @@ from .utils import new_tasks
datasets.disable_caching() datasets.disable_caching()
tasks.initialize_tasks() task_manager = tasks.TaskManager()
# Default Task # Default Task
TASKS = ["arc_easy"] TASKS = ["arc_easy"]
...@@ -21,9 +21,9 @@ def task_class(): ...@@ -21,9 +21,9 @@ def task_class():
task_classes = new_tasks() task_classes = new_tasks()
# Check if task_classes is empty # Check if task_classes is empty
if task_classes: if task_classes:
return [tasks.TASK_REGISTRY.get(x)() for x in task_classes] return list(task_manager.load_task_or_group(task_classes).values())
else: else:
return [tasks.TASK_REGISTRY.get(x)() for x in TASKS] return list(task_manager.load_task_or_group(TASKS).values())
@pytest.fixture() @pytest.fixture()
......
import os import os
from pathlib import Path
from typing import List, Union from typing import List, Union
from lm_eval.utils import load_yaml_config from lm_eval.utils import load_yaml_config
...@@ -20,16 +19,18 @@ def load_changed_files(file_path: str) -> List[str]: ...@@ -20,16 +19,18 @@ def load_changed_files(file_path: str) -> List[str]:
# checks the txt file for list of changed files. # checks the txt file for list of changed files.
# if file ends with .yaml then check yaml for task name # if file ends with .yaml then check yaml and load the config.
# if file ends with .py then parse the folder for all yaml files # if the config task is a string, it's a task config.
# if the config task is a list, it's a group config.
def parser(full_path: List[str]) -> List[str]: def parser(full_path: List[str]) -> List[str]:
_output = set() _output = set()
for x in full_path: for x in full_path:
if x.endswith(".yaml"): if os.path.exists(x) and x.endswith(".yaml"):
_output.add(load_yaml_config(x)["task"]) config = load_yaml_config(x, mode="simple")
elif x.endswith(".py"): if isinstance(config["task"], str):
path = [str(x) for x in (list(Path(x).parent.glob("*.yaml")))] _output.add(config["task"])
_output |= {load_yaml_config(x)["task"] for x in path} elif isinstance(config["task"], list):
_output.add(config["group"])
return list(_output) return list(_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment