Commit ba75c838 authored by baberabb's avatar baberabb
Browse files

changed task_class fixture

parent 3864b560
...@@ -4,11 +4,10 @@ from typing import List ...@@ -4,11 +4,10 @@ from typing import List
import lm_eval.tasks as tasks import lm_eval.tasks as tasks
from lm_eval.api.task import ConfigurableTask from lm_eval.api.task import ConfigurableTask
# Using fixtures to get the task class and limit
@pytest.fixture() @pytest.fixture()
def task_class(task_name: List[str] = None) -> ConfigurableTask: def task_class(task_name: List[str]) -> ConfigurableTask:
if task_name is None: task_name = ["arc_easy"]
task_name = ["arc_easy"]
x = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name] x = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
return x[0] return x[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment