"vscode:/vscode.git/clone" did not exist on "81561f8e2d55d105aabbe0eab1b3b33f4fc04b0b"
Commit 8c419c83 authored by seopbo
Browse files

Fix triviaqa task

parent 8cff2bea
......@@ -11,10 +11,10 @@ Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
......@@ -74,19 +74,27 @@ class TriviaQA(Task):
return ret
def construct_requests(self, doc, ctx):
    """Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc:
        The document as returned from training_docs, validation_docs, or
        test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the
        natural language description, as well as the few shot examples, and
        the question part of the document for `doc`.
    :return:
        The request built by `rf.greedy_until` for `ctx`, with newline,
        period and comma passed as the `until` stop sequences.
    """
    return rf.greedy_until(ctx, {"until": ["\n", ".", ","]})

def process_results(self, doc, results):
    """Score one document by exact match against its answer aliases.

    The first entry of `results` is stripped of surrounding whitespace,
    lowercased and has ASCII punctuation removed. Every alias returned by
    `self._remove_prefixes(doc["answer"]["aliases"])` is lowercased and has
    ASCII punctuation removed the same way before the comparison.

    :param doc:
        The document being scored; `doc["answer"]["aliases"]` is read.
    :param results:
        Sequence whose first item is the generated continuation string.
    :return:
        `{"em": 1.0}` when the normalized continuation equals one of the
        normalized aliases, otherwise `{"em": 0.0}`.
    """
    # Build the punctuation-stripping table once instead of once per alias.
    strip_punctuation = str.maketrans("", "", string.punctuation)
    continuation = results[0].strip().lower().translate(strip_punctuation)
    candidates = {
        alias.lower().translate(strip_punctuation)
        for alias in self._remove_prefixes(doc["answer"]["aliases"])
    }
    return {"em": float(continuation in candidates)}

def aggregation(self):
    """Return the aggregation function for each reported metric.

    :return:
        Mapping from the metric name emitted by `process_results` ("em")
        to `mean`.
    """
    return {
        "em": mean,
    }

def higher_is_better(self):
    """Return, per reported metric, whether a larger value is better.

    :return:
        `{"em": True}`.
    """
    return {"em": True}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment