Commit bd028848 authored by Baber

Merge branch 'main' into metrics

# Conflicts:
#	tests/test_tasks.py
parents 6e48110e 56def33d
@@ -11,7 +11,9 @@ try:
 except (ModuleNotFoundError, ImportError):
     raise ModuleNotFoundError(
-        "Please install evaluation metrics via pip install evaluate and pip install bert-score",
+        "Please install evaluation metrics via pip install evaluate bert-score "
+        "rouge_score>=0.1.2 nltk absl-py "
+        "git+https://github.com/google-research/bleurt.git"
     )
 except Exception as e:
     raise RuntimeError(

@@ -11,7 +11,9 @@ try:
 except (ModuleNotFoundError, ImportError):
     raise ModuleNotFoundError(
-        "Please install evaluation metrics via pip install evaluate and pip install bert-score",
+        "Please install evaluation metrics via pip install evaluate bert-score "
+        "rouge_score>=0.1.2 nltk absl-py "
+        "git+https://github.com/google-research/bleurt.git"
     )
 except Exception as e:
     raise RuntimeError(

@@ -15,7 +15,9 @@ try:
 except (ModuleNotFoundError, ImportError):
     raise ModuleNotFoundError(
-        "Please install evaluation metrics via pip install evaluate and pip install bert-score",
+        "Please install evaluation metrics via pip install evaluate bert-score "
+        "rouge_score>=0.1.2 nltk absl-py radgraph "
+        "git+https://github.com/google-research/bleurt.git"
     )
 except Exception as e:
     raise RuntimeError(

@@ -11,7 +11,9 @@ try:
 except (ModuleNotFoundError, ImportError):
     raise ModuleNotFoundError(
-        "Please install evaluation metrics via pip install evaluate and pip install bert-score",
+        "Please install evaluation metrics via pip install evaluate bert-score "
+        "rouge_score>=0.1.2 nltk absl-py "
+        "git+https://github.com/google-research/bleurt.git"
     )
 except Exception as e:
     raise RuntimeError(

@@ -12,7 +12,9 @@ try:
 except (ModuleNotFoundError, ImportError):
     raise ModuleNotFoundError(
-        "Please install evaluation metrics via pip install evaluate and pip install bert-score",
+        "Please install evaluation metrics via pip install evaluate bert-score "
+        "rouge_score>=0.1.2 nltk absl-py "
+        "git+https://github.com/google-research/bleurt.git"
    )
 except Exception as e:
     raise RuntimeError(

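All five hunks above make the same change to the import guard at the top of a metrics module (the third also adds radgraph): the single-string install hint is replaced by one message built from adjacent string literals, which Python concatenates at compile time. A minimal sketch of the resulting pattern; the guarded import and the RuntimeError text vary per file and are assumed here for illustration:

    # Sketch of the import-guard pattern after this commit. The guarded
    # imports differ per metrics module; "evaluate" is illustrative only.
    try:
        import evaluate  # assumed: each module guards its own metric imports
    except (ModuleNotFoundError, ImportError):
        raise ModuleNotFoundError(
            "Please install evaluation metrics via pip install evaluate bert-score "
            "rouge_score>=0.1.2 nltk absl-py "
            "git+https://github.com/google-research/bleurt.git"
        )
    except Exception as e:
        # assumed wording; the repo's RuntimeError message is not shown in full
        raise RuntimeError(f"Unexpected error importing evaluation metrics: {e}") from e
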
@@ -579,10 +579,11 @@ def hash_dict_images(data_dict):
         dict: A new dictionary with the same structure as `data_dict`, but with all
         bytes and PIL.Image.Image objects replaced by their hashes.
     """
-    from PIL import Image

     def _process_value(value):
         # Bytes -> hash
+        from PIL import Image
+
         if isinstance(value, (bytes, bytearray)):
             return convert_bytes_to_hash(value)
         # PIL Image -> hash

@@ -603,4 +604,8 @@ def hash_dict_images(data_dict):

     if not isinstance(data_dict, dict):
         raise TypeError("Input must be a dictionary")
-    return {key: _process_value(val) for key, val in data_dict.items()}
+    return (
+        {key: _process_value(val) for key, val in data_dict.items()}
+        if importlib.util.find_spec("PIL")
+        else data_dict
+    )

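Taken together, the two hunks above make Pillow an optional dependency: the import moves inside the per-value helper, which only runs after importlib.util.find_spec confirms PIL is installed; otherwise the dictionary passes through untouched. A rough sketch of the resulting function; convert_bytes_to_hash and the image branch are assumed from the surrounding module, not shown in this diff:

    import hashlib
    import importlib.util

    def convert_bytes_to_hash(b):
        # assumed stand-in for the helper defined elsewhere in the module
        return hashlib.sha256(bytes(b)).hexdigest()

    def hash_dict_images(data_dict):
        """Replace bytes/PIL.Image values with hashes; no-op without Pillow."""
        def _process_value(value):
            # Safe: this import only runs when find_spec("PIL") succeeded below.
            from PIL import Image

            # Bytes -> hash
            if isinstance(value, (bytes, bytearray)):
                return convert_bytes_to_hash(value)
            # PIL Image -> hash (assumed: the real branch hashes the raw pixels)
            if isinstance(value, Image.Image):
                return convert_bytes_to_hash(value.tobytes())
            return value

        if not isinstance(data_dict, dict):
            raise TypeError("Input must be a dictionary")
        return (
            {key: _process_value(val) for key, val in data_dict.items()}
            if importlib.util.find_spec("PIL")
            else data_dict
        )
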
@@ -46,12 +46,7 @@ def limit() -> int:
     return 10


-@pytest.mark.parametrize(
-    "task_class",
-    task_class(get_new_tasks_else_default()),
-    ids=lambda x: f"{x.config.task}",
-)
-class TestBaseTasks:
+class BaseTasks:
     """
     Base class for testing tasks
     """

@@ -165,50 +160,8 @@ class TestBaseTasks:
     task_class(get_new_tasks_else_default()),
     ids=lambda x: f"{x.config.task}",
 )
-class TestNewTasksElseDefault(TestBaseTasks):
+class TestNewTasksElseDefault(BaseTasks):
     """
     Test class parameterized with a list of new/modified tasks
     (or a set of default tasks if none have been modified)
     """
-
-
-@pytest.mark.parametrize(
-    "task_class",
-    task_class(
-        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
-    ),
-    ids=lambda x: f"{x.config.task}",
-)
-class TestUnitxtTasks(TestBaseTasks):
-    """
-    Test class for Unitxt tasks parameterized with a small custom
-    task as described here:
-    https://www.unitxt.ai/en/latest/docs/lm_eval.html
-    """
-
-    def test_check_training_docs(self, task_class: ConfigurableTask):
-        if task_class.has_training_docs():
-            assert task_class.dataset["train"] is not None
-
-    def test_check_validation_docs(self, task_class):
-        if task_class.has_validation_docs():
-            assert task_class.dataset["validation"] is not None
-
-    def test_check_test_docs(self, task_class):
-        task = task_class
-        if task.has_test_docs():
-            assert task.dataset["test"] is not None
-
-    def test_doc_to_text(self, task_class, limit: int):
-        task = task_class
-        arr = (
-            list(islice(task.test_docs(), limit))
-            if task.has_test_docs()
-            else list(islice(task.validation_docs(), limit))
-        )
-        _array = [task.doc_to_text(doc) for doc in arr]
-        if not task.multiple_input:
-            for x in _array:
-                assert isinstance(x, str)
-        else:
-            pass

@@ -0,0 +1,49 @@
+from itertools import islice
+
+import pytest
+
+from lm_eval import tasks as tasks
+from lm_eval.api.task import ConfigurableTask
+from tests.test_tasks import BaseTasks, task_class
+
+
+@pytest.mark.parametrize(
+    "task_class",
+    task_class(
+        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
+    ),
+    ids=lambda x: f"{x.config.task}",
+)
+class TestUnitxtTasks(BaseTasks):
+    """
+    Test class for Unitxt tasks parameterized with a small custom
+    task as described here:
+    https://www.unitxt.ai/en/latest/docs/lm_eval.html
+    """
+
+    def test_check_training_docs(self, task_class: ConfigurableTask):
+        if task_class.has_training_docs():
+            assert task_class.dataset["train"] is not None
+
+    def test_check_validation_docs(self, task_class):
+        if task_class.has_validation_docs():
+            assert task_class.dataset["validation"] is not None
+
+    def test_check_test_docs(self, task_class):
+        task = task_class
+        if task.has_test_docs():
+            assert task.dataset["test"] is not None
+
+    def test_doc_to_text(self, task_class, limit: int):
+        task = task_class
+        arr = (
+            list(islice(task.test_docs(), limit))
+            if task.has_test_docs()
+            else list(islice(task.validation_docs(), limit))
+        )
+        _array = [task.doc_to_text(doc) for doc in arr]
+        if not task.multiple_input:
+            for x in _array:
+                assert isinstance(x, str)
+        else:
+            pass
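
The moved suite is still selectable by class name, so it can be run in isolation even though the new file's path is not shown in this diff. An illustrative invocation, not part of the commit:

    # Illustrative only: select the relocated Unitxt suite by class name.
    import pytest

    pytest.main(["tests/", "-k", "TestUnitxtTasks"])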