Commit 224b0854 authored by lintangsutawika's avatar lintangsutawika
Browse files

change import origin

parent 8171906d
...@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions. ...@@ -10,11 +10,12 @@ high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/ Homepage: https://nlp.cs.washington.edu/triviaqa/
""" """
import inspect import inspect
# import lm_eval.datasets.triviaqa.triviaqa # import lm_eval.datasets.triviaqa.triviaqa
import string import string
from lm_eval.api.task import Task from lm_eval.api.task import Task
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.register import register_task from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean from lm_eval.api.metrics import mean
_CITATION = """ _CITATION = """
...@@ -29,10 +30,11 @@ _CITATION = """ ...@@ -29,10 +30,11 @@ _CITATION = """
} }
""" """
@register_task("triviaqa") @register_task("triviaqa")
class TriviaQA(Task): class TriviaQA(Task):
VERSION = 1 VERSION = 1
DATASET_PATH = "trivia_qa" #inspect.getfile(lm_eval.datasets.triviaqa.triviaqa) DATASET_PATH = "trivia_qa" # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext" DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until" OUTPUT_TYPE = "greedy_until"
...@@ -90,18 +92,29 @@ class TriviaQA(Task): ...@@ -90,18 +92,29 @@ class TriviaQA(Task):
continuation = Instance( continuation = Instance(
request_type=self.OUTPUT_TYPE, request_type=self.OUTPUT_TYPE,
doc=doc, doc=doc,
arguments=(ctx, { arguments=(
ctx,
{
"until": ["\n", ".", ","], "until": ["\n", ".", ","],
"do_sample": False, "do_sample": False,
}), },
),
idx=0, idx=0,
**kwargs, **kwargs,
) )
return continuation return continuation
def process_results(self, doc, results): def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation)) continuation = (
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])] results[0]
.strip()
.lower()
.translate(str.maketrans("", "", string.punctuation))
)
list_of_candidates = [
alias.lower().translate(str.maketrans("", "", string.punctuation))
for alias in self._remove_prefixes(doc["answer"]["aliases"])
]
return {"em": float(continuation in list_of_candidates)} return {"em": float(continuation in list_of_candidates)}
def aggregation(self): def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment