Commit 00afd536 authored by Baber
Browse files

fix tests

parent 1de882c2
...@@ -3,13 +3,18 @@ from __future__ import annotations ...@@ -3,13 +3,18 @@ from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Any
from lm_eval.api.task import Task
from lm_eval.tasks.factory import TaskFactory from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging from lm_eval.utils import setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import Task
class TaskManager: class TaskManager:
def __init__( def __init__(
self, self,
...@@ -77,3 +82,20 @@ class TaskManager: ...@@ -77,3 +82,20 @@ class TaskManager:
if isinstance(task_list, list) if isinstance(task_list, list)
else [self.load_spec(task_list)] else [self.load_spec(task_list)]
) )
def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
):
    """Resolve task entries into a dict keyed by the entries themselves.

    Args:
        task_name_list: A single task name, or a list of entries. String
            entries are resolved through ``task_manager.load_spec``; any
            other entry is passed through unchanged.
        task_manager: Manager used to resolve names. A default
            ``TaskManager()`` is created when none is supplied.

    Returns:
        A dict mapping each entry to ``task_manager.load_spec(entry)`` when
        the entry is a string, and to the entry itself otherwise.
    """
    if not task_manager:
        task_manager = TaskManager()
    else:
        assert isinstance(task_manager, TaskManager)

    # A bare string is ONE task name; iterating it directly would look up
    # every single character as if it were a task.
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]

    # NOTE(review): non-string entries are used as their own dict key, so a
    # plain ``dict`` entry raises TypeError (unhashable) — confirm intended.
    return {
        entry: task_manager.load_spec(entry) if isinstance(entry, str) else entry
        for entry in task_name_list
    }
import importlib # import importlib
import os # import os
import sys # import sys
from datetime import datetime # from datetime import datetime
from typing import List, Optional, Tuple # from typing import List, Optional, Tuple
#
import pytest # import pytest
import torch # import torch
#
from lm_eval.caching.cache import PATH # from lm_eval.caching.cache import PATH
#
#
# Directory containing this test module (symlinks resolved by realpath).
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

# NOTE the script this loads uses simple evaluate
# TODO potentially test both the helper script and the normal script
# Put the sibling ``scripts`` directory on sys.path so that the
# ``requests_caching`` module can be imported by name on the next line.
sys.path.append(f"{MODULE_DIR}/../scripts")
model_loader = importlib.import_module("requests_caching")
run_model_for_task_caching = model_loader.run_model_for_task_caching

# NOTE(review): presumably lets Hugging Face datasets execute remote dataset
# code without prompting — confirm against the datasets documentation.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# Task names that the parametrized functions below are run against.
DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Autouse fixture: disable torch deterministic algorithms and empty the cache.

    Everything before ``yield`` runs as setup; there is no teardown work
    after it.
    """
    # Setup
    torch.use_deterministic_algorithms(False)
    clear_cache()
    # Hand control to the test function; nothing is cleaned up afterwards.
    yield
#
#
def clear_cache():
    """Unlink every entry found directly inside the cache directory ``PATH``.

    Does nothing when ``PATH`` does not exist.
    """
    if not os.path.exists(PATH):
        return
    for entry in os.listdir(PATH):
        os.unlink(f"{PATH}/{entry}")
#
#
# leaving tasks here to allow for the option to select specific task files # # leaving tasks here to allow for the option to select specific task files
def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
    """List the cache directory and derive a task name from each file name.

    Args:
        tasks: Accepted for future filtering; not used by the current body.

    Returns:
        ``(cache_files, file_task_names)`` where ``cache_files`` is the raw
        ``os.listdir(PATH)`` result and ``file_task_names`` holds, for each
        file, the second ``"-"``-separated field cut at its first ``"."``.
    """
    cache_files = os.listdir(PATH)

    # "<prefix>-<task>.<suffix>..." -> "<task>"
    file_task_names = [name.split("-")[1].split(".")[0] for name in cache_files]

    return cache_files, file_task_names
#
#
def assert_created(tasks: List[str], file_task_names: List[str]):
    """Assert that the cached task names match the requested tasks, ignoring order.

    Args:
        tasks: Task names that were requested.
        file_task_names: Task names derived from the cache file names.

    Raises:
        AssertionError: If the two lists differ once sorted.
    """
    # Compare sorted copies rather than sorting in place: ``tasks`` is usually
    # the shared module-level DEFAULT_TASKS list and must not be reordered as
    # a side effect of a check.
    assert sorted(tasks) == sorted(file_task_names)
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_true(tasks: List[str]):
    """Run with ``cache_requests="true"`` and check the cached task names match ``tasks``.

    NOTE(review): the name lacks a ``test_`` prefix, so default pytest
    collection would skip it — confirm this is intentional.
    """
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    _, cached_task_names = get_cache_files()
    print(cached_task_names)
    assert_created(tasks=tasks, file_task_names=cached_task_names)
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_refresh(tasks: List[str]):
    """Check that ``cache_requests="refresh"`` rewrites every cache file.

    The cache is first populated with ``cache_requests="true"``, a timestamp
    is taken, and the model is run again with ``"refresh"``. Every cache file
    must then carry a modification time later than that timestamp, and the
    cached task names must match ``tasks``.
    """
    # Populate the cache so the refresh run has files to rewrite.
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    timestamp_before_test = datetime.now().timestamp()

    run_model_for_task_caching(tasks=tasks, cache_requests="refresh")

    cache_files, file_task_names = get_cache_files()

    for file in cache_files:
        modification_time = os.path.getmtime(f"{PATH}/{file}")
        assert modification_time > timestamp_before_test

    # Use the shared helper instead of repeating the sort-and-compare inline,
    # matching requests_caching_true.
    assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_delete(tasks: List[str]):
    """Run with ``cache_requests="delete"`` and check the cache directory is empty."""
    # populate the data first, rerun this test within this test for additional confidence
    # test_requests_caching_true(tasks=tasks)

    run_model_for_task_caching(tasks=tasks, cache_requests="delete")

    remaining_files, _ = get_cache_files()

    assert not remaining_files
#
#
# useful for locally running tests through the debugger # # useful for locally running tests through the debugger
if __name__ == "__main__":

    def run_tests():
        """Call each enabled test with DEFAULT_TASKS, clearing the cache before each."""
        # Every entry is currently commented out, so the loop below runs zero times.
        enabled_tests = [
            # test_requests_caching_true,
            # test_requests_caching_refresh,
            # test_requests_caching_delete,
        ]
        # Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first
        task_names = DEFAULT_TASKS
        for run_one in enabled_tests:
            clear_cache()
            run_one(tasks=task_names)

        print("Tests pass")

    run_tests()
...@@ -13,7 +13,7 @@ from pathlib import Path ...@@ -13,7 +13,7 @@ from pathlib import Path
import pytest import pytest
from lm_eval.tasks._task_index import TaskIndex, TaskKind from lm_eval.tasks.index import Kind, TaskIndex
@pytest.fixture @pytest.fixture
...@@ -33,28 +33,28 @@ def yaml_file(temp_dir): ...@@ -33,28 +33,28 @@ def yaml_file(temp_dir):
return _create_yaml return _create_yaml
class TestTaskKindOf: class TestKindOf:
"""Tests for identifying task configuration types.""" """Tests for identifying task configuration types."""
def test_kind_of_task(self): def test_kind_of_task(self):
"""Single task with string name.""" """Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"} cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndex._kind_of(cfg) == TaskKind.TASK assert TaskIndex._kind_of(cfg) == Kind.TASK
def test_kind_of_group(self): def test_kind_of_group(self):
"""Group has task as list.""" """Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"} cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndex._kind_of(cfg) == TaskKind.GROUP assert TaskIndex._kind_of(cfg) == Kind.GROUP
def test_kind_of_py_task(self): def test_kind_of_py_task(self):
"""Python task has class field.""" """Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"} cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndex._kind_of(cfg) == TaskKind.PY_TASK assert TaskIndex._kind_of(cfg) == Kind.PY_TASK
def test_kind_of_task_list(self): def test_kind_of_task_list(self):
"""Task list has task_list field.""" """Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]} cfg = {"task_list": ["task1", "task2"]}
assert TaskIndex._kind_of(cfg) == TaskKind.TASK_LIST assert TaskIndex._kind_of(cfg) == Kind.TASK_LIST
def test_kind_of_unknown(self): def test_kind_of_unknown(self):
"""Unknown config raises ValueError.""" """Unknown config raises ValueError."""
...@@ -75,7 +75,7 @@ class TestIterYamlFiles: ...@@ -75,7 +75,7 @@ class TestIterYamlFiles:
(temp_dir / "other.txt").touch() (temp_dir / "other.txt").touch()
builder = TaskIndex() builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files()) yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 2 assert len(yaml_files) == 2
names = {f.name for f in yaml_files} names = {f.name for f in yaml_files}
...@@ -90,7 +90,7 @@ class TestIterYamlFiles: ...@@ -90,7 +90,7 @@ class TestIterYamlFiles:
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch() (temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndex() builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files()) yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 1 assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml" assert yaml_files[0].name == "task.yaml"
...@@ -111,7 +111,7 @@ class TestProcessCfg: ...@@ -111,7 +111,7 @@ class TestProcessCfg:
assert "my_task" in index assert "my_task" in index
entry = index["my_task"] entry = index["my_task"]
assert entry.name == "my_task" assert entry.name == "my_task"
assert entry.kind == TaskKind.TASK assert entry.kind == Kind.TASK
assert entry.yaml_path == path assert entry.yaml_path == path
assert entry.tags == {"tag1", "tag2"} assert entry.tags == {"tag1", "tag2"}
...@@ -127,7 +127,7 @@ class TestProcessCfg: ...@@ -127,7 +127,7 @@ class TestProcessCfg:
assert "my_group" in index assert "my_group" in index
entry = index["my_group"] entry = index["my_group"]
assert entry.name == "my_group" assert entry.name == "my_group"
assert entry.kind == TaskKind.GROUP assert entry.kind == Kind.GROUP
assert entry.yaml_path == path assert entry.yaml_path == path
assert entry.tags == {"grp_tag"} assert entry.tags == {"grp_tag"}
...@@ -143,7 +143,7 @@ class TestProcessCfg: ...@@ -143,7 +143,7 @@ class TestProcessCfg:
assert "py_task" in index assert "py_task" in index
entry = index["py_task"] entry = index["py_task"]
assert entry.name == "py_task" assert entry.name == "py_task"
assert entry.kind == TaskKind.PY_TASK assert entry.kind == Kind.PY_TASK
assert entry.yaml_path is None # Python tasks don't store yaml_path assert entry.yaml_path is None # Python tasks don't store yaml_path
assert entry.tags == {"py_tag"} assert entry.tags == {"py_tag"}
...@@ -181,14 +181,14 @@ class TestProcessCfg: ...@@ -181,14 +181,14 @@ class TestProcessCfg:
# Task without tags # Task without tags
assert "task1" in index assert "task1" in index
task1 = index["task1"] task1 = index["task1"]
assert task1.kind == TaskKind.TASK assert task1.kind == Kind.TASK
assert task1.yaml_path == path assert task1.yaml_path == path
assert task1.tags == set() assert task1.tags == set()
# Task with tags # Task with tags
assert "task2" in index assert "task2" in index
task2 = index["task2"] task2 = index["task2"]
assert task2.kind == TaskKind.TASK assert task2.kind == Kind.TASK
assert task2.yaml_path == path assert task2.yaml_path == path
assert task2.tags == {"tag1", "tag2"} assert task2.tags == {"tag1", "tag2"}
...@@ -205,7 +205,7 @@ class TestRegisterTags: ...@@ -205,7 +205,7 @@ class TestRegisterTags:
assert "my_tag" in index assert "my_tag" in index
tag_entry = index["my_tag"] tag_entry = index["my_tag"]
assert tag_entry.kind == TaskKind.TAG assert tag_entry.kind == Kind.TAG
assert tag_entry.yaml_path is None assert tag_entry.yaml_path is None
assert "task1" in tag_entry.tags # TAG entries use tags set for task names assert "task1" in tag_entry.tags # TAG entries use tags set for task names
...@@ -252,7 +252,7 @@ class TestBuild: ...@@ -252,7 +252,7 @@ class TestBuild:
assert len(index) == 1 assert len(index) == 1
assert "my_task" in index assert "my_task" in index
assert index["my_task"].kind == TaskKind.TASK assert index["my_task"].kind == Kind.TASK
def test_build_mixed_types(self, temp_dir, yaml_file): def test_build_mixed_types(self, temp_dir, yaml_file):
"""Discovers various task types.""" """Discovers various task types."""
...@@ -283,12 +283,12 @@ class TestBuild: ...@@ -283,12 +283,12 @@ class TestBuild:
assert "common" in index # Tag entry assert "common" in index # Tag entry
# Check types # Check types
assert index["task1"].kind == TaskKind.TASK assert index["task1"].kind == Kind.TASK
assert index["group1"].kind == TaskKind.GROUP assert index["group1"].kind == Kind.GROUP
assert index["task2"].kind == TaskKind.TASK assert index["task2"].kind == Kind.TASK
assert index["task3"].kind == TaskKind.TASK assert index["task3"].kind == Kind.TASK
assert index["py_task"].kind == TaskKind.PY_TASK assert index["py_task"].kind == Kind.PY_TASK
assert index["common"].kind == TaskKind.TAG assert index["common"].kind == Kind.TAG
# Check tag has both tasks # Check tag has both tasks
assert index["common"].tags == {"task1", "task3"} assert index["common"].tags == {"task1", "task3"}
......
...@@ -64,10 +64,10 @@ def test_python_task_inclusion( ...@@ -64,10 +64,10 @@ def test_python_task_inclusion(
verbosity="INFO", include_path=str(custom_task_files_dir) verbosity="INFO", include_path=str(custom_task_files_dir)
) )
# check if python tasks enters the global task_index # check if python tasks enters the global task_index
assert custom_task_name in task_manager.task_index assert custom_task_name in task_manager._index
# check if subtask is present # check if subtask is present
assert custom_task_name in task_manager.all_subtasks assert custom_task_name in task_manager._index
# check if tag is present # check if tag is present
assert custom_task_tag in task_manager.all_tags assert custom_task_tag in task_manager._index
# check if it can be loaded by tag (custom_task_tag) # check if it can be loaded by tag (custom_task_tag)
assert custom_task_name in task_manager.load_task_or_group(custom_task_tag) assert custom_task_name in task_manager.load_task_or_group(custom_task_tag)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment