Commit ba75c838 authored by baberabb's avatar baberabb
Browse files

changed task_class fixture

parent 3864b560
......@@ -4,11 +4,10 @@ from typing import List
import lm_eval.tasks as tasks
from lm_eval.api.task import ConfigurableTask
# Using fixtures to get the task class and limit
@pytest.fixture()
def task_class(task_name: List[str] = None) -> ConfigurableTask:
if task_name is None:
task_name = ["arc_easy"]
def task_class(task_name: List[str]) -> ConfigurableTask:
task_name = ["arc_easy"]
x = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
return x[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment