"vscode:/vscode.git/clone" did not exist on "a7248f6ea8fc277b81916dffb238cdcb1f0d9c58"
Unverified commit 8b5c5c13 authored by Kiersten Stokes, committed by GitHub
Browse files

Add test for a simple Unitxt task (#2742)

* Add a test for a custom unitxt task

* Update task.py to bring in line with breaking change in v1.17.2

* Fix lint
parent d35008f1
......@@ -34,6 +34,15 @@ def assert_unitxt_installed():
"Please install unitxt via 'pip install unitxt'. For more information see: https://www.unitxt.ai/"
)
from unitxt import __version__ as unitxt_version
# Function argument change due to https://github.com/IBM/unitxt/pull/1564
unitxt_version = tuple(map(int, (unitxt_version.split("."))))
if unitxt_version < (1, 17, 2):
raise Exception(
"Please install a more recent version of unitxt via 'pip install --upgrade unitxt' to avoid errors due to breaking changes"
)
def score(items, metric):
predictions, references = zip(*items)
......@@ -69,7 +78,7 @@ class Unitxt(ConfigurableTask):
assert_unitxt_installed()
from unitxt import load_dataset
self.dataset = load_dataset(self.DATASET_NAME, disable_cache=False)
self.dataset = load_dataset(self.DATASET_NAME, use_cache=True)
def has_training_docs(self):
    """Return True if the loaded dataset exposes a "train" split."""
    available_splits = self.dataset
    return "train" in available_splits
......
......@@ -58,7 +58,7 @@ Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
[project.optional-dependencies]
api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"]
......
......@@ -13,18 +13,29 @@ from .utils import new_tasks
# Allow HF datasets that ship remote loading code, and disable HF
# tokenizers parallelism for the test run.
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Shared manager used to resolve built-in task names.
task_manager = tasks.TaskManager()
# Default Task — fallback list used when CI detects no new/modified tasks.
TASKS = ["arc_easy"]
def task_class():
def get_new_tasks_else_default():
    """
    Check if any modifications have been made to built-in tasks and return
    the list, otherwise return the default task list.

    Returns:
        A list of task names: the CI-detected new/modified tasks, or the
        module-level default ``TASKS`` when nothing was modified.
    """
    # CI: new_tasks checks if any modifications have been made.
    # No ``global`` statement needed — TASKS is only read, never assigned.
    task_classes = new_tasks()
    # Fall back to the default task list when nothing changed.
    return task_classes if task_classes else TASKS
def task_class(
    task_names=None, task_manager=None
) -> "list[ConfigurableTask]":
    """
    Convert a list of task names to a list of ConfigurableTask instances.

    Args:
        task_names: task names passed through to ``tasks.get_task_dict``.
        task_manager: optional TaskManager; a default one is created when None.

    Returns:
        list[ConfigurableTask]: one instantiated task per resolved name.
        (Fixed: the previous annotation claimed a single ConfigurableTask.)
    """
    if task_manager is None:
        task_manager = tasks.TaskManager()
    res = tasks.get_task_dict(task_names, task_manager)
    res = [x.task for x in get_task_list(res)]
    return res
......@@ -36,8 +47,11 @@ def limit() -> int:
# Tests
@pytest.mark.parametrize("task_class", task_class(), ids=lambda x: f"{x.config.task}")
class TestNewTasks:
class BaseTasks:
"""
Base class for testing tasks
"""
def test_download(self, task_class: ConfigurableTask):
    """Downloading a task must populate its dataset attribute."""
    task_class.download()
    downloaded = task_class.dataset
    assert downloaded is not None
......@@ -140,3 +154,57 @@ class TestNewTasks:
for doc in arr
]
assert len(requests) == limit if limit else True
# Runs the shared BaseTasks checks against the CI-detected task set.
# NOTE: the parametrize argument is evaluated at import/collection time.
@pytest.mark.parametrize(
    "task_class",
    task_class(get_new_tasks_else_default()),
    ids=lambda x: f"{x.config.task}",
)
class TestNewTasksElseDefault(BaseTasks):
    """
    Test class parameterized with a list of new/modified tasks
    (or a set of default tasks if none have been modified).

    All actual test methods are inherited from BaseTasks.
    """
@pytest.mark.parametrize(
    "task_class",
    task_class(
        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
    ),
    ids=lambda x: f"{x.config.task}",
)
class TestUnitxtTasks(BaseTasks):
    """
    Test class for Unitxt tasks parameterized with a small custom
    task as described here:
    https://www.unitxt.ai/en/latest/docs/lm_eval.html
    """

    def test_check_training_docs(self, task_class: ConfigurableTask):
        """A task that reports a train split must actually expose it."""
        if task_class.has_training_docs():
            assert task_class.dataset["train"] is not None

    def test_check_validation_docs(self, task_class: ConfigurableTask):
        """A task that reports a validation split must actually expose it."""
        if task_class.has_validation_docs():
            assert task_class.dataset["validation"] is not None

    def test_check_test_docs(self, task_class: ConfigurableTask):
        """A task that reports a test split must actually expose it."""
        # Removed the pointless ``task = task_class`` alias.
        if task_class.has_test_docs():
            assert task_class.dataset["test"] is not None

    def test_doc_to_text(self, task_class: ConfigurableTask, limit: int):
        """doc_to_text must render strings for single-input tasks."""
        # Prefer the test split; fall back to validation when absent.
        docs = (
            list(islice(task_class.test_docs(), limit))
            if task_class.has_test_docs()
            else list(islice(task_class.validation_docs(), limit))
        )
        rendered = [task_class.doc_to_text(doc) for doc in docs]
        # Multiple-input tasks may render non-string prompts, so only the
        # single-input case is checked (dead ``else: pass`` branch removed).
        if not task_class.multiple_input:
            for text in rendered:
                assert isinstance(text, str)
# Minimal custom Unitxt task config used by the test suite
# (loaded via TaskManager(include_path="./tests/testconfigs")).
task: arc_easy_unitxt
include: ../../lm_eval/tasks/unitxt/unitxt
# Unitxt recipe: ARC-Easy card with an open multiple-choice QA template
# (see https://www.unitxt.ai/en/latest/docs/lm_eval.html).
recipe: card=cards.ai2_arc.arc_easy,template=templates.qa.multiple_choice.open.all
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment