Unverified Commit b952a206 authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #746 from EleutherAI/StellaAthena-patch-2

Update triviaqa.py
parents 5a49b2a3 f08f7c79
...@@ -28,7 +28,7 @@ _CITATION = """ ...@@ -28,7 +28,7 @@ _CITATION = """
class TriviaQA(Task): class TriviaQA(Task):
VERSION = 2 VERSION = 3
DATASET_PATH = "trivia_qa" DATASET_PATH = "trivia_qa"
DATASET_NAME = "rc.nocontext" DATASET_NAME = "rc.nocontext"
...@@ -62,16 +62,6 @@ class TriviaQA(Task): ...@@ -62,16 +62,6 @@ class TriviaQA(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return " " + doc["answer"]["value"] return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -87,7 +77,7 @@ class TriviaQA(Task): ...@@ -87,7 +77,7 @@ class TriviaQA(Task):
def process_results(self, doc, results): def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation)) continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])] list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in doc["answer"]["aliases"]]
return {"em": float(continuation in list_of_candidates)} return {"em": float(continuation in list_of_candidates)}
def aggregation(self): def aggregation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment