"examples/community/latent_consistency_interpolate.py" did not exist on "0e0db625d0b7da2e6c27336845514b4460bd0000"
Commit 00afd536 authored by Baber's avatar Baber
Browse files

fix tests

parent 1de882c2
......@@ -3,13 +3,18 @@ from __future__ import annotations
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
from lm_eval.api.task import Task
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import Task
class TaskManager:
def __init__(
self,
......@@ -77,3 +82,20 @@ class TaskManager:
if isinstance(task_list, list)
else [self.load_spec(task_list)]
)
def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
):
    """Build a mapping from each requested task to its loaded spec.

    Args:
        task_name_list: A single task name, or a list of entries. String
            entries are loaded through ``task_manager.load_spec``; any other
            entry is passed through unchanged and used as its own key.
        task_manager: Manager used to load string entries. When falsy, a
            default ``TaskManager()`` is created.

    Returns:
        A dict with one item per entry: ``name -> load_spec(name)`` for
        strings, ``entry -> entry`` otherwise.
    """
    # A bare string is one task name; without this wrap the comprehension
    # below would iterate over its individual characters.
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]

    if not task_manager:
        task_manager = TaskManager()
    else:
        assert isinstance(task_manager, TaskManager)

    # NOTE(review): the annotation admits dict entries, but a dict is
    # unhashable and cannot be a key here, so such an entry raises TypeError.
    # Confirm how dict configs are meant to be keyed before relying on this.
    return {
        task_name: (
            task_manager.load_spec(task_name)
            if isinstance(task_name, str)
            else task_name
        )
        for task_name in task_name_list
    }
import importlib
import os
import sys
from datetime import datetime
from typing import List, Optional, Tuple
import pytest
import torch
from lm_eval.caching.cache import PATH
# Directory containing this file, with symlinks resolved by realpath.
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
# NOTE the script this loads uses simple evaluate
# TODO potentially test both the helper script and the normal script
# Extend the import path with ../scripts (relative to this file). This has to
# happen before the import_module call below, which looks the module up by
# its bare name.
sys.path.append(f"{MODULE_DIR}/../scripts")
model_loader = importlib.import_module("requests_caching")
run_model_for_task_caching = model_loader.run_model_for_task_caching
# Set process-wide at import time.
# NOTE(review): presumably allows datasets that ship remote code to load - confirm.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# Task names fed to the parametrized caching checks in this module.
DEFAULT_TASKS = ["lambada_openai", "sciq"]
@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Autouse fixture: disable deterministic algorithms and empty the cache.

    Runs before every test in this module. Nothing is done after the test
    finishes; the code following ``yield`` is intentionally empty.
    """
    # --- setup ---
    torch.use_deterministic_algorithms(False)
    clear_cache()

    # Hand control to the test body.
    yield

    # --- teardown: nothing to clean up at present ---
def clear_cache():
    """Unlink every entry directly inside ``PATH``.

    Does nothing when ``PATH`` does not exist.
    """
    if not os.path.exists(PATH):
        return

    for entry in os.listdir(PATH):
        os.unlink(f"{PATH}/{entry}")
# `tasks` is currently unused; it stays in the signature so that selecting
# specific task files remains an option.
def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
    """List the cache directory and derive a task name from each file name.

    The task name is the second ``-``-separated field of the file name,
    truncated at its first ``.``.

    Returns:
        A pair ``(cache_files, file_task_names)`` where ``cache_files`` is the
        raw ``os.listdir(PATH)`` result and ``file_task_names`` holds the
        derived names in the same order.
    """
    cache_files = os.listdir(PATH)
    file_task_names = [
        name.split("-")[1].split(".")[0] for name in cache_files
    ]
    return cache_files, file_task_names
def assert_created(tasks: List[str], file_task_names: List[str]):
    """Assert that both lists hold the same names, ignoring order.

    Args:
        tasks: Expected task names.
        file_task_names: Task names derived from the cache files.

    Raises:
        AssertionError: If the two lists differ once sorted.
    """
    # Compare sorted copies instead of sorting in place: callers pass shared
    # lists (e.g. the parametrized default task list) that must not be
    # reordered as a side effect of an assertion helper.
    assert sorted(tasks) == sorted(file_task_names)
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_true(tasks: List[str]):
    """Run with ``cache_requests="true"`` and check the cache file names.

    After the run, the task names derived from the cache directory must
    match ``tasks``.
    """
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    _, cached_task_names = get_cache_files()
    print(cached_task_names)
    assert_created(tasks=tasks, file_task_names=cached_task_names)
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_refresh(tasks: List[str]):
    """Check that a ``"refresh"`` run rewrites every cache file.

    The cache is first populated with ``cache_requests="true"``. A timestamp
    is taken, the run is repeated with ``cache_requests="refresh"``, and each
    cache file must then have a modification time later than that timestamp.
    Finally the task names derived from the cache files must match ``tasks``.
    """
    # First pass populates the cache so the refresh has something to rewrite.
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    timestamp_before_test = datetime.now().timestamp()

    run_model_for_task_caching(tasks=tasks, cache_requests="refresh")

    cache_files, file_task_names = get_cache_files()

    for file in cache_files:
        modification_time = os.path.getmtime(f"{PATH}/{file}")
        assert modification_time > timestamp_before_test

    # Compare sorted copies: `tasks` is the very list object handed to
    # parametrize (DEFAULT_TASKS), so sorting it in place would reorder the
    # shared module-level default for every later test.
    assert sorted(tasks) == sorted(file_task_names)
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_delete(tasks: List[str]):
    """Run with ``cache_requests="delete"`` and expect an empty cache directory."""
    # For extra confidence the cache could be populated first by re-running
    # the "true" check from inside this one:
    # test_requests_caching_true(tasks=tasks)
    run_model_for_task_caching(tasks=tasks, cache_requests="delete")

    remaining_files, _ = get_cache_files()
    assert not remaining_files
# Handy for stepping through the tests locally under a debugger.
if __name__ == "__main__":

    def run_tests():
        """Call each enabled test with DEFAULT_TASKS, clearing the cache before each."""
        enabled = [
            # test_requests_caching_true,
            # test_requests_caching_refresh,
            # test_requests_caching_delete,
        ]
        # Bind the module-level constant to a local name once, ahead of the loop.
        task_names = DEFAULT_TASKS
        for check in enabled:
            clear_cache()
            check(tasks=task_names)

        print("Tests pass")

    run_tests()
# import importlib
# import os
# import sys
# from datetime import datetime
# from typing import List, Optional, Tuple
#
# import pytest
# import torch
#
# from lm_eval.caching.cache import PATH
#
#
# MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
#
# # NOTE the script this loads uses simple evaluate
# # TODO potentially test both the helper script and the normal script
# sys.path.append(f"{MODULE_DIR}/../scripts")
# model_loader = importlib.import_module("requests_caching")
# run_model_for_task_caching = model_loader.run_model_for_task_caching
#
# os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
# @pytest.fixture(autouse=True)
# def setup_and_teardown():
# # Setup
# torch.use_deterministic_algorithms(False)
# clear_cache()
# # Yields control back to the test function
# yield
# # Cleanup here
#
#
# def clear_cache():
# if os.path.exists(PATH):
# cache_files = os.listdir(PATH)
# for file in cache_files:
# file_path = f"{PATH}/{file}"
# os.unlink(file_path)
#
#
# # leaving tasks here to allow for the option to select specific task files
# def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
# cache_files = os.listdir(PATH)
#
# file_task_names = []
#
# for file in cache_files:
# file_without_prefix = file.split("-")[1]
# file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
# file_task_names.extend([file_without_prefix_and_suffix])
#
# return cache_files, file_task_names
#
#
# def assert_created(tasks: List[str], file_task_names: List[str]):
# tasks.sort()
# file_task_names.sort()
#
# assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_true(tasks: List[str]):
# run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
# cache_files, file_task_names = get_cache_files()
# print(file_task_names)
# assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_refresh(tasks: List[str]):
# run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
# timestamp_before_test = datetime.now().timestamp()
#
# run_model_for_task_caching(tasks=tasks, cache_requests="refresh")
#
# cache_files, file_task_names = get_cache_files()
#
# for file in cache_files:
# modification_time = os.path.getmtime(f"{PATH}/{file}")
# assert modification_time > timestamp_before_test
#
# tasks.sort()
# file_task_names.sort()
#
# assert tasks == file_task_names
#
#
# @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
# def requests_caching_delete(tasks: List[str]):
# # populate the data first, rerun this test within this test for additional confidence
# # test_requests_caching_true(tasks=tasks)
#
# run_model_for_task_caching(tasks=tasks, cache_requests="delete")
#
# cache_files, file_task_names = get_cache_files()
#
# assert len(cache_files) == 0
#
#
# # useful for locally running tests through the debugger
# if __name__ == "__main__":
#
# def run_tests():
# tests = [
# # test_requests_caching_true,
# # test_requests_caching_refresh,
# # test_requests_caching_delete,
# ]
# # Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first
# default_tasks = DEFAULT_TASKS
# for test_func in tests:
# clear_cache()
# test_func(tasks=default_tasks)
#
# print("Tests pass")
#
# run_tests()
......@@ -13,7 +13,7 @@ from pathlib import Path
import pytest
from lm_eval.tasks._task_index import TaskIndex, TaskKind
from lm_eval.tasks.index import Kind, TaskIndex
@pytest.fixture
......@@ -33,28 +33,28 @@ def yaml_file(temp_dir):
return _create_yaml
class TestTaskKindOf:
class TestKindOf:
"""Tests for identifying task configuration types."""
def test_kind_of_task(self):
"""Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndex._kind_of(cfg) == TaskKind.TASK
assert TaskIndex._kind_of(cfg) == Kind.TASK
def test_kind_of_group(self):
"""Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndex._kind_of(cfg) == TaskKind.GROUP
assert TaskIndex._kind_of(cfg) == Kind.GROUP
def test_kind_of_py_task(self):
"""Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndex._kind_of(cfg) == TaskKind.PY_TASK
assert TaskIndex._kind_of(cfg) == Kind.PY_TASK
def test_kind_of_task_list(self):
"""Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]}
assert TaskIndex._kind_of(cfg) == TaskKind.TASK_LIST
assert TaskIndex._kind_of(cfg) == Kind.TASK_LIST
def test_kind_of_unknown(self):
"""Unknown config raises ValueError."""
......@@ -75,7 +75,7 @@ class TestIterYamlFiles:
(temp_dir / "other.txt").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files())
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 2
names = {f.name for f in yaml_files}
......@@ -90,7 +90,7 @@ class TestIterYamlFiles:
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files())
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml"
......@@ -111,7 +111,7 @@ class TestProcessCfg:
assert "my_task" in index
entry = index["my_task"]
assert entry.name == "my_task"
assert entry.kind == TaskKind.TASK
assert entry.kind == Kind.TASK
assert entry.yaml_path == path
assert entry.tags == {"tag1", "tag2"}
......@@ -127,7 +127,7 @@ class TestProcessCfg:
assert "my_group" in index
entry = index["my_group"]
assert entry.name == "my_group"
assert entry.kind == TaskKind.GROUP
assert entry.kind == Kind.GROUP
assert entry.yaml_path == path
assert entry.tags == {"grp_tag"}
......@@ -143,7 +143,7 @@ class TestProcessCfg:
assert "py_task" in index
entry = index["py_task"]
assert entry.name == "py_task"
assert entry.kind == TaskKind.PY_TASK
assert entry.kind == Kind.PY_TASK
assert entry.yaml_path is None # Python tasks don't store yaml_path
assert entry.tags == {"py_tag"}
......@@ -181,14 +181,14 @@ class TestProcessCfg:
# Task without tags
assert "task1" in index
task1 = index["task1"]
assert task1.kind == TaskKind.TASK
assert task1.kind == Kind.TASK
assert task1.yaml_path == path
assert task1.tags == set()
# Task with tags
assert "task2" in index
task2 = index["task2"]
assert task2.kind == TaskKind.TASK
assert task2.kind == Kind.TASK
assert task2.yaml_path == path
assert task2.tags == {"tag1", "tag2"}
......@@ -205,7 +205,7 @@ class TestRegisterTags:
assert "my_tag" in index
tag_entry = index["my_tag"]
assert tag_entry.kind == TaskKind.TAG
assert tag_entry.kind == Kind.TAG
assert tag_entry.yaml_path is None
assert "task1" in tag_entry.tags # TAG entries use tags set for task names
......@@ -252,7 +252,7 @@ class TestBuild:
assert len(index) == 1
assert "my_task" in index
assert index["my_task"].kind == TaskKind.TASK
assert index["my_task"].kind == Kind.TASK
def test_build_mixed_types(self, temp_dir, yaml_file):
"""Discovers various task types."""
......@@ -283,12 +283,12 @@ class TestBuild:
assert "common" in index # Tag entry
# Check types
assert index["task1"].kind == TaskKind.TASK
assert index["group1"].kind == TaskKind.GROUP
assert index["task2"].kind == TaskKind.TASK
assert index["task3"].kind == TaskKind.TASK
assert index["py_task"].kind == TaskKind.PY_TASK
assert index["common"].kind == TaskKind.TAG
assert index["task1"].kind == Kind.TASK
assert index["group1"].kind == Kind.GROUP
assert index["task2"].kind == Kind.TASK
assert index["task3"].kind == Kind.TASK
assert index["py_task"].kind == Kind.PY_TASK
assert index["common"].kind == Kind.TAG
# Check tag has both tasks
assert index["common"].tags == {"task1", "task3"}
......
......@@ -64,10 +64,10 @@ def test_python_task_inclusion(
verbosity="INFO", include_path=str(custom_task_files_dir)
)
# check if python tasks enters the global task_index
assert custom_task_name in task_manager.task_index
assert custom_task_name in task_manager._index
# check if subtask is present
assert custom_task_name in task_manager.all_subtasks
assert custom_task_name in task_manager._index
# check if tag is present
assert custom_task_tag in task_manager.all_tags
assert custom_task_tag in task_manager._index
# check if it can be loaded by tag (custom_task_tag)
assert custom_task_name in task_manager.load_task_or_group(custom_task_tag)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment