Commit 537a77fd authored by Baber
Browse files

nit

parent 2e43c8d7
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools # noqa: I001 import itertools # noqa: I001
import random import random
from functools import cache, partial from functools import cache, partial
...@@ -9,18 +24,18 @@ from transformers import AutoTokenizer ...@@ -9,18 +24,18 @@ from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
# config = { config = {
# 'tokens_to_generate': 32, "tokens_to_generate": 32,
# 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""", "template": """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
# 'answer_prefix': """ Answer:""", "answer_prefix": """ Answer:""",
# } }
SEED = 42 SEED = 42
TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""" TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}"""
DOCUMENT_PROMPT = "Document {i}:\n{document}" DOCUMENT_PROMPT = "Document {i}:\n{document}"
@cache @cache
def download_json(url): def download_json(url) -> dict:
response = requests.get(url) response = requests.get(url)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
...@@ -28,7 +43,9 @@ def download_json(url): ...@@ -28,7 +43,9 @@ def download_json(url):
@cache @cache
def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"): def read_squad(
url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
) -> tuple[list[dict], list[str]]:
data = download_json(url) data = download_json(url)
total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]] total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
total_docs = sorted(list(set(total_docs))) total_docs = sorted(list(set(total_docs)))
...@@ -59,7 +76,7 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0. ...@@ -59,7 +76,7 @@ def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.
@cache @cache
def read_hotpotqa( def read_hotpotqa(
url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json", url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
): ) -> tuple[list[dict], list[str]]:
data = download_json(url) data = download_json(url)
total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d["context"]] total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d["context"]]
total_docs = sorted(list(set(total_docs))) total_docs = sorted(list(set(total_docs)))
...@@ -195,7 +212,7 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs): ...@@ -195,7 +212,7 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
docs=docs, docs=docs,
qas=qas, qas=qas,
num_samples=500, num_samples=500,
tokens_to_generate=128, tokens_to_generate=32,
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
) )
return write_jsons return write_jsons
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment