Unverified Commit b952a206 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #746 from EleutherAI/StellaAthena-patch-2

Update triviaqa.py
parents 5a49b2a3 f08f7c79
......@@ -28,7 +28,7 @@ _CITATION = """
class TriviaQA(Task):
VERSION = 2
VERSION = 3
DATASET_PATH = "trivia_qa"
DATASET_NAME = "rc.nocontext"
......@@ -62,16 +62,6 @@ class TriviaQA(Task):
def doc_to_target(self, doc):
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -87,7 +77,7 @@ class TriviaQA(Task):
def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in doc["answer"]["aliases"]]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment