Unverified Commit 0dd45190 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #610 from Vermeille/patch-1

[triviaqa] The ground truth must be a *substring* of the generated an…
parents 5e56fbf6 0570991a
......@@ -28,7 +28,7 @@ _CITATION = """
class TriviaQA(Task):
VERSION = 2
VERSION = 3
DATASET_PATH = "trivia_qa"
DATASET_NAME = "rc.nocontext"
......@@ -86,9 +86,11 @@ class TriviaQA(Task):
return continuation
def process_results(self, doc, results):
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
return {"em": float(continuation in list_of_candidates)}
generated = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_truth_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
def match(candidate):
return candidate in generated
return {"em": float(any(match(candidate) for candidate in list_of_truth_candidates))}
def aggregation(self):
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment