Commit 2e43c8d7 authored by Baber's avatar Baber
Browse files

add qa_hotpot

parent c8a3f2e2
include: qa_squad.yaml
task: ruler_qa_hotpot
download_dataset: !function qa_utils.get_hotpotqa
include: niah_1.yaml
task: ruler_qa
download_dataset: !function qa_utils.get_qa_dataset
task: ruler_qa_squad
download_dataset: !function qa_utils.get_squad
test_split: test
generation_kwargs:
do_sample: false
......
import itertools # noqa: I001
import random
from functools import cache
from functools import cache, partial
import datasets
import requests
......@@ -27,6 +27,7 @@ def download_json(url):
return data
@cache
def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"):
data = download_json(url)
total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
......@@ -55,6 +56,30 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
return total_qas, total_docs
@cache
def read_hotpotqa(
url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
):
data = download_json(url)
total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d["context"]]
total_docs = sorted(list(set(total_docs)))
total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
total_qas = []
for d in data:
total_qas.append(
{
"query": d["question"],
"outputs": [d["answer"]],
"context": [
total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d["context"]
],
}
)
return total_qas, total_docs
def generate_input_output(
index: int, num_docs: int, qas: list[dict], docs: list[str]
) -> tuple[str, list[str]]:
......@@ -176,10 +201,13 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
return write_jsons
def get_qa_dataset(**kwargs):
def get_qa_dataset(ds, **kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
qas, docs = read_squad()
if ds == "squad":
qas, docs = read_squad()
else:
qas, docs = read_hotpotqa()
df = (
get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in SEQ_LENGTHS
......@@ -190,3 +218,7 @@ def get_qa_dataset(**kwargs):
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
)
}
get_squad = partial(get_qa_dataset, "squad")
get_hotpotqa = partial(get_qa_dataset, "hotpotqa")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment