Unverified Commit b018a7d5 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #525 from seopbo/fix-triviaqa

Fix triviaqa task
parents 137d0c43 50cf0984
---
dataset_info:
features:
- name: question_id
dtype: string
- name: question_source
dtype: string
- name: question
dtype: string
- name: answer
struct:
- name: aliases
sequence: string
- name: value
dtype: string
- name: search_results
sequence:
- name: description
dtype: string
- name: filename
dtype: string
- name: rank
dtype: int32
- name: title
dtype: string
- name: url
dtype: string
- name: search_context
dtype: string
config_name: triviaqa
splits:
- name: train
num_bytes: 1270894387
num_examples: 87622
- name: validation
num_bytes: 163755044
num_examples: 11313
download_size: 632549060
dataset_size: 1434649431
---
{"triviaqa": {"description": "TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence\ntriples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts\nand independently gathered evidence documents, six per question on average, that provide\nhigh quality distant supervision for answering the questions.\n", "citation": "@InProceedings{JoshiTriviaQA2017,\n author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},\n title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},\n booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},\n month = {July},\n year = {2017},\n address = {Vancouver, Canada},\n publisher = {Association for Computational Linguistics},\n}\n", "homepage": "https://nlp.cs.washington.edu/triviaqa/", "license": "Apache License 2.0", "features": {"question_id": {"dtype": "string", "id": null, "_type": "Value"}, "question_source": {"dtype": "string", "id": null, "_type": "Value"}, "question": {"dtype": "string", "id": null, "_type": "Value"}, "answer": {"aliases": {"feature": {"dtype": "string", "id": null, "_type": "Value"}, "length": -1, "id": null, "_type": "Sequence"}, "value": {"dtype": "string", "id": null, "_type": "Value"}}, "search_results": {"feature": {"description": {"dtype": "string", "id": null, "_type": "Value"}, "filename": {"dtype": "string", "id": null, "_type": "Value"}, "rank": {"dtype": "int32", "id": null, "_type": "Value"}, "title": {"dtype": "string", "id": null, "_type": "Value"}, "url": {"dtype": "string", "id": null, "_type": "Value"}, "search_context": {"dtype": "string", "id": null, "_type": "Value"}}, "length": -1, "id": null, "_type": "Sequence"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "triviaqa", "config_name": "triviaqa", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"train": {"name": "train", "num_bytes": 1271393601, "num_examples": 87622, "dataset_name": "triviaqa"}, "validation": {"name": "validation", "num_bytes": 163819509, "num_examples": 11313, "dataset_name": "triviaqa"}}, "download_checksums": {"http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz": {"num_bytes": 546481381, "checksum": "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e"}}, "download_size": 546481381, "post_processing_size": null, "dataset_size": 1435213110, "size_in_bytes": 1981694491}}
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Custom TriviaQA because HF version sanitizes the dataset differently.
# https://github.com/huggingface/datasets/blob/9977ade72191ff0b6907ec63935448c6269a91a1/datasets/trivia_qa/trivia_qa.py#L285
"""TriviaQA (Unfiltered Raw) dataset."""
import json
import os
import datasets
_CITATION = """\
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
_DESCRIPTION = """\
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
"""
_HOMEPAGE = "https://nlp.cs.washington.edu/triviaqa/"
_LICENSE = "Apache License 2.0"
_URLS = "https://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz"
class Triviaqa(datasets.GeneratorBasedBuilder):
"""TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence triples"""
VERSION = datasets.Version("0.0.2")
BUILDER_CONFIGS = [
datasets.BuilderConfig(
name="triviaqa", version=VERSION, description="The TriviaQA dataset"
),
]
def _info(self):
features = datasets.Features(
{
"question_id": datasets.Value("string"),
"question_source": datasets.Value("string"),
"question": datasets.Value("string"),
"answer": {
"aliases": datasets.features.Sequence(
datasets.Value("string"),
),
"value": datasets.Value("string"),
},
"search_results": datasets.features.Sequence(
{
"description": datasets.Value("string"),
"filename": datasets.Value("string"),
"rank": datasets.Value("int32"),
"title": datasets.Value("string"),
"url": datasets.Value("string"),
"search_context": datasets.Value("string"),
}
),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
urls = _URLS
data_dir = dl_manager.download_and_extract(urls)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-train.json"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-dev.json"
),
},
),
]
# method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
def _generate_examples(self, filepath):
with open(filepath, encoding="utf-8") as f:
json_data = json.load(f)["Data"]
for key, data in enumerate(json_data):
search_results = []
for search_result in data["SearchResults"]:
search_results.append(
{
"description": search_result["Description"]
if "Description" in search_result
else "",
"filename": search_result["Filename"]
if "Filename" in search_result
else "",
"rank": search_result["Rank"]
if "Rank" in search_result
else -1,
"title": search_result["Title"]
if "Title" in search_result
else "",
"url": search_result["Url"]
if "Url" in search_result
else "",
"search_context": search_result["SearchContext"]
if "SearchContext" in search_result
else "",
}
)
yield key, {
"question_id": data["QuestionId"],
"question_source": data["QuestionSource"],
"question": data["Question"],
"answer": {
"aliases": data["Answer"]["Aliases"],
"value": data["Answer"]["Value"],
},
"search_results": search_results,
}
...@@ -10,11 +10,10 @@ high quality distant supervision for answering the questions. ...@@ -10,11 +10,10 @@ high quality distant supervision for answering the questions.
Homepage: https://nlp.cs.washington.edu/triviaqa/ Homepage: https://nlp.cs.washington.edu/triviaqa/
""" """
import inspect import inspect
import lm_eval.datasets.triviaqa.triviaqa import string
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
_CITATION = """ _CITATION = """
@InProceedings{JoshiTriviaQA2017, @InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke}, author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
...@@ -29,9 +28,9 @@ _CITATION = """ ...@@ -29,9 +28,9 @@ _CITATION = """
class TriviaQA(Task): class TriviaQA(Task):
VERSION = 1 VERSION = 2
DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa) DATASET_PATH = "trivia_qa"
DATASET_NAME = None DATASET_NAME = "rc.nocontext"
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -74,19 +73,27 @@ class TriviaQA(Task): ...@@ -74,19 +73,27 @@ class TriviaQA(Task):
return ret return ret
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
ret = [] """Uses RequestFactory to construct Requests and returns an iterable of
for alias in self._remove_prefixes(doc["answer"]["aliases"]): Requests which will be sent to the LM.
_, is_prediction = rf.loglikelihood(ctx, " " + alias) :param doc:
ret.append(is_prediction) The document as returned from training_docs, validation_docs, or test_docs.
return ret :param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
return continuation
def process_results(self, doc, results): def process_results(self, doc, results):
return {"acc": float(any(results))} continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
return {"em": float(continuation in list_of_candidates)}
def aggregation(self): def aggregation(self):
return { return {
"acc": mean, "em": mean,
} }
def higher_is_better(self): def higher_is_better(self):
return {"acc": True} return {"em": True}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment