"include/vscode:/vscode.git/clone" did not exist on "5d7fb2a1059bddca71a55e9d5f9d3e9143d20f37"
Commit 537a77fd authored by Baber
Browse files

nit

parent 2e43c8d7
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools # noqa: I001
import random
from functools import cache, partial
......@@ -9,18 +24,18 @@ from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
# config = {
# 'tokens_to_generate': 32,
# 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
# 'answer_prefix': """ Answer:""",
# }
# Fixed seed value; the code that consumes it is outside this excerpt.
SEED = 42

# Prompt skeleton with two placeholders, ``{context}`` and ``{query}``.
# The instruction sentence appears both before and after the context.
TEMPLATE = (
    "Answer the question based on the given documents. "
    "Only give me the answer and do not output any other words.\n\n"
    "The following are given documents.\n\n"
    "{context}\n\n"
    "Answer the question based on the given documents. "
    "Only give me the answer and do not output any other words.\n\n"
    "Question: {query}"
)

# Per-document wrapper; ``{i}`` is the document index and ``{document}`` its text.
DOCUMENT_PROMPT = "Document {i}:\n{document}"

# Task settings bundled in one mapping. ``template`` carries the same text as
# TEMPLATE, so it is referenced rather than repeated.
config = {
    "tokens_to_generate": 32,
    "template": TEMPLATE,
    "answer_prefix": " Answer:",
}
@cache
def download_json(url):
def download_json(url) -> dict:
response = requests.get(url)
response.raise_for_status()
data = response.json()
......@@ -28,7 +43,9 @@ def download_json(url):
@cache
def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"):
def read_squad(
url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
) -> tuple[list[dict], list[str]]:
data = download_json(url)
total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
total_docs = sorted(list(set(total_docs)))
......@@ -59,7 +76,7 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
@cache
def read_hotpotqa(
url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
):
) -> tuple[list[dict], list[str]]:
data = download_json(url)
total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d["context"]]
total_docs = sorted(list(set(total_docs)))
......@@ -195,7 +212,7 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
docs=docs,
qas=qas,
num_samples=500,
tokens_to_generate=128,
tokens_to_generate=32,
max_seq_length=max_seq_length,
)
return write_jsons
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment