Commit 7d282b5f authored by cjlovering's avatar cjlovering
Browse files

Add PromptSourceTask to the updated tasks

parent 88745155
......@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
from lm_eval.base import PromptSourceTask, rf, mean
from itertools import zip_longest
......@@ -28,7 +28,7 @@ _CITATION = """
"""
class CoQA(Task):
class CoQA(PromptSourceTask):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
DATASET_NAME = None
......
......@@ -18,7 +18,7 @@ import re
import string
import lm_eval.datasets.drop.drop
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.base import PromptSourceTask, rf
from lm_eval.metrics import mean
......@@ -37,7 +37,7 @@ _CITATION = """
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
class DROP(Task):
class DROP(PromptSourceTask):
VERSION = 1
DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
DATASET_NAME = None
......
......@@ -12,7 +12,7 @@ Homepage: https://www.cs.cmu.edu/~glai1/data/race/
import collections
import datasets
import numpy as np
from lm_eval.base import rf, Task
from lm_eval.base import PromptSourceTask, rf
from lm_eval.metrics import mean
......@@ -34,7 +34,7 @@ class each:
return list(map(self.f, other))
class RACE(Task):
class RACE(PromptSourceTask):
VERSION = 1
DATASET_PATH = "race"
DATASET_NAME = "high"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment