Unverified Commit f3a0b554 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

warning for "chat" pretrained; disable buggy evalita configs (#3127)

* check for chat for warning

* add test

* remove yaml extension from some evalita configs

* move unitxt to own test script

* fix CI test
parent ab3acc73
...@@ -154,15 +154,23 @@ def simple_evaluate( ...@@ -154,15 +154,23 @@ def simple_evaluate(
"Either 'limit' or 'samples' must be None, but both are not None." "Either 'limit' or 'samples' must be None, but both are not None."
) )
_NEEDS_CHAT_TEMPLATE = ("inst", "chat")
if ( if (
(isinstance(model_args, str) and "inst" in model_args.lower()) (
isinstance(model_args, str)
and any(kw in model_args.lower() for kw in _NEEDS_CHAT_TEMPLATE)
)
or ( or (
isinstance(model_args, dict) isinstance(model_args, dict)
and any("inst" in str(v).lower() for v in model_args.values()) and any(
any(kw in str(v).lower() for kw in _NEEDS_CHAT_TEMPLATE)
for v in model_args.values()
)
) )
) and not apply_chat_template: ) and not apply_chat_template:
eval_logger.warning( eval_logger.warning(
"Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)." "Model appears to be an instruct or chat variant but chat template is not applied. "
"Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
) )
if delete_requests_cache: if delete_requests_cache:
......
...@@ -141,7 +141,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -141,7 +141,7 @@ class MultiChoiceRegexFilter(RegexFilter):
""" """
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. - step 2 : We parse the choice with regex: r'\s*([A-?])', where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result. group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching ignore_punctuation: Remove the punctuation during step 1 matching
......
...@@ -46,7 +46,6 @@ def limit() -> int: ...@@ -46,7 +46,6 @@ def limit() -> int:
return 10 return 10
# Tests
class BaseTasks: class BaseTasks:
""" """
Base class for testing tasks Base class for testing tasks
...@@ -166,45 +165,3 @@ class TestNewTasksElseDefault(BaseTasks): ...@@ -166,45 +165,3 @@ class TestNewTasksElseDefault(BaseTasks):
Test class parameterized with a list of new/modified tasks Test class parameterized with a list of new/modified tasks
(or a set of default tasks if none have been modified) (or a set of default tasks if none have been modified)
""" """
@pytest.mark.parametrize(
    "task_class",
    task_class(
        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
    ),
    ids=lambda x: f"{x.config.task}",
)
class TestUnitxtTasks(BaseTasks):
    """
    Checks for Unitxt-backed tasks, run against the small custom
    ``arc_easy_unitxt`` task loaded from ``./tests/testconfigs``.

    Background on the Unitxt integration:
    https://www.unitxt.ai/en/latest/docs/lm_eval.html
    """

    def test_check_training_docs(self, task_class: ConfigurableTask):
        """If the task advertises training docs, the "train" split must exist."""
        if not task_class.has_training_docs():
            return
        assert task_class.dataset["train"] is not None

    def test_check_validation_docs(self, task_class):
        """If the task advertises validation docs, the "validation" split must exist."""
        if not task_class.has_validation_docs():
            return
        assert task_class.dataset["validation"] is not None

    def test_check_test_docs(self, task_class):
        """If the task advertises test docs, the "test" split must exist."""
        if not task_class.has_test_docs():
            return
        assert task_class.dataset["test"] is not None

    def test_doc_to_text(self, task_class, limit: int):
        """Render up to ``limit`` docs and check each prompt is a ``str``.

        Test docs are used when the task has them, validation docs otherwise.
        The type check is skipped for tasks flagged ``multiple_input``.
        """
        if task_class.has_test_docs():
            doc_source = task_class.test_docs()
        else:
            doc_source = task_class.validation_docs()
        # Materialise the sample first so docs are fetched before any rendering.
        sample = list(islice(doc_source, limit))
        rendered = [task_class.doc_to_text(doc) for doc in sample]
        if task_class.multiple_input:
            return
        for prompt in rendered:
            assert isinstance(prompt, str)
from itertools import islice
import pytest
from lm_eval import tasks as tasks
from lm_eval.api.task import ConfigurableTask
from tests.test_tasks import BaseTasks, task_class
@pytest.mark.parametrize(
    "task_class",
    task_class(
        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
    ),
    ids=lambda x: f"{x.config.task}",
)
class TestUnitxtTasks(BaseTasks):
    """
    Checks for Unitxt-backed tasks, run against the small custom
    ``arc_easy_unitxt`` task loaded from ``./tests/testconfigs``.

    Background on the Unitxt integration:
    https://www.unitxt.ai/en/latest/docs/lm_eval.html
    """

    def test_check_training_docs(self, task_class: ConfigurableTask):
        """If the task advertises training docs, the "train" split must exist."""
        if not task_class.has_training_docs():
            return
        assert task_class.dataset["train"] is not None

    def test_check_validation_docs(self, task_class):
        """If the task advertises validation docs, the "validation" split must exist."""
        if not task_class.has_validation_docs():
            return
        assert task_class.dataset["validation"] is not None

    def test_check_test_docs(self, task_class):
        """If the task advertises test docs, the "test" split must exist."""
        if not task_class.has_test_docs():
            return
        assert task_class.dataset["test"] is not None

    def test_doc_to_text(self, task_class, limit: int):
        """Render up to ``limit`` docs and check each prompt is a ``str``.

        Test docs are used when the task has them, validation docs otherwise.
        The type check is skipped for tasks flagged ``multiple_input``.
        """
        if task_class.has_test_docs():
            doc_source = task_class.test_docs()
        else:
            doc_source = task_class.validation_docs()
        # Materialise the sample first so docs are fetched before any rendering.
        sample = list(islice(doc_source, limit))
        rendered = [task_class.doc_to_text(doc) for doc in sample]
        if task_class.multiple_input:
            return
        for prompt in rendered:
            assert isinstance(prompt, str)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment