Commit 8c419c83 authored by seopbo's avatar seopbo
Browse files

Fix triviaqa task

parent 8cff2bea
...@@ -11,10 +11,10 @@ Homepage: https://nlp.cs.washington.edu/triviaqa/ ...@@ -11,10 +11,10 @@ Homepage: https://nlp.cs.washington.edu/triviaqa/
""" """
import inspect import inspect
import lm_eval.datasets.triviaqa.triviaqa import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@InProceedings{JoshiTriviaQA2017, @InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke}, author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
...@@ -74,19 +74,27 @@ class TriviaQA(Task): ...@@ -74,19 +74,27 @@ class TriviaQA(Task):
return ret return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] """Uses RequestFactory to construct Requests and returns an iterable of
for alias in self._remove_prefixes(doc["answer"]["aliases"]): Requests which will be sent to the LM.
_, is_prediction = rf.loglikelihood(ctx, " " + alias) :param doc:
ret.append(is_prediction) The document as returned from training_docs, validation_docs, or test_docs.
return ret :param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
return continuation
def process_results(self, doc, results): def process_results(self, doc, results):
return {"acc": float(any(results))} continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self): def aggregation(self):
return { return {
"acc": mean, "em": mean,
} }
def higher_is_better(self): def higher_is_better(self):
return {"acc": True} return {"em": True}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment