Unverified Commit c7572ba6 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Delete triviaqa.py

parent 025fa6e8
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
@register_task("triviaqa")
class TriviaQA(Task):
VERSION = 1
DATASET_PATH = "trivia_qa" # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
DATASET_NAME = "unfiltered.nocontext"
OUTPUT_TYPE = "greedy_until"
def has_training_docs(self):
return True
def has_validation_docs(self):
return True
def has_test_docs(self):
return False
def training_docs(self):
return self.dataset["train"]
def validation_docs(self):
return self.dataset["validation"]
def test_docs(self):
raise NotImplementedError()
def doc_to_text(self, doc):
return f"Q: {doc['question']}\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["question"]
def doc_to_target(self, doc):
return " " + doc["answer"]["value"]
def _remove_prefixes(self, aliases):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases.sort()
ret = [aliases[0]]
for alias in aliases[1:]:
if not alias.startswith(ret[-1]):
ret.append(alias)
return ret
def construct_requests(self, doc, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(
ctx,
{
"until": ["\n", ".", ","],
"do_sample": False,
},
),
idx=0,
**kwargs,
)
return continuation
def process_results(self, doc, results):
continuation = (
results[0]
.strip()
.lower()
.translate(str.maketrans("", "", string.punctuation))
)
list_of_candidates = [
alias.lower().translate(str.maketrans("", "", string.punctuation))
for alias in self._remove_prefixes(doc["answer"]["aliases"])
]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self):
return {
"em": mean,
}
def higher_is_better(self):
return {"em": True}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment