Commit 2e43c8d7 authored by Baber's avatar Baber
Browse files

add qa_hotpot

parent c8a3f2e2
include: qa_squad.yaml
task: ruler_qa_hotpot
download_dataset: !function qa_utils.get_hotpotqa
include: niah_1.yaml
task: ruler_qa_squad
download_dataset: !function qa_utils.get_squad
test_split: test
generation_kwargs:
  do_sample: false
......
import itertools  # noqa: I001
import random
from functools import cache, partial

import datasets
import requests
...@@ -27,6 +27,7 @@ def download_json(url): ...@@ -27,6 +27,7 @@ def download_json(url):
    return data
@cache
def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"):
    data = download_json(url)
    total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
...@@ -55,6 +56,30 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0. ...@@ -55,6 +56,30 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
    return total_qas, total_docs
@cache
def read_hotpotqa(
    url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
):
    """Download the HotpotQA dev (distractor) split and restructure it.

    Returns a tuple ``(total_qas, total_docs)``:
      * ``total_docs`` — sorted, de-duplicated list of passage strings, each
        built from a (title, paragraphs) pair joined by a newline;
      * ``total_qas`` — one dict per question with keys ``query`` (question
        text), ``outputs`` (single-element list holding the gold answer), and
        ``context`` (indices into ``total_docs`` of the question's passages).

    Cached with ``functools.cache`` so the JSON is only fetched once per URL.
    """
    data = download_json(url)

    def _passage(title, paragraphs):
        # Single place that defines the passage text format, so the
        # de-duplicated docs and the index lookup below can never diverge.
        return f"{title}\n{''.join(paragraphs)}"

    # Sorted, de-duplicated passages: deterministic order keeps the indices
    # stored in each QA entry stable across runs.
    total_docs = sorted({_passage(t, p) for d in data for t, p in d["context"]})
    total_docs_dict = {doc: idx for idx, doc in enumerate(total_docs)}

    total_qas = [
        {
            "query": d["question"],
            "outputs": [d["answer"]],
            "context": [total_docs_dict[_passage(t, p)] for t, p in d["context"]],
        }
        for d in data
    ]
    return total_qas, total_docs
def generate_input_output( def generate_input_output(
index: int, num_docs: int, qas: list[dict], docs: list[str] index: int, num_docs: int, qas: list[dict], docs: list[str]
) -> tuple[str, list[str]]: ) -> tuple[str, list[str]]:
...@@ -176,10 +201,13 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs): ...@@ -176,10 +201,13 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
    return write_jsons
def get_qa_dataset(ds, **kwargs):
    kwargs = kwargs.get("metadata", {})
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    if ds == "squad":
        qas, docs = read_squad()
    else:
        qas, docs = read_hotpotqa()
    df = (
        get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
        for seq in SEQ_LENGTHS
...@@ -190,3 +218,7 @@ def get_qa_dataset(**kwargs): ...@@ -190,3 +218,7 @@ def get_qa_dataset(**kwargs):
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
) )
} }
# YAML-facing entry points: the task configs reference these via
# `!function qa_utils.get_squad` / `!function qa_utils.get_hotpotqa`,
# each binding the dataset selector of `get_qa_dataset`.
get_squad = partial(get_qa_dataset, "squad")
get_hotpotqa = partial(get_qa_dataset, "hotpotqa")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment