Commit ca9316c9 authored by Baber

add other ruler tasks

parent 92953afa
@@ -83,7 +83,7 @@ def sys_word_pair_random(
     tokenizer=None,
     incremental: int = 10,
     remove_newline_tab=False,
-    tokens_to_generate=30,
+    tokens_to_generate=120,
 ):
     assert tokenizer is not None, "Tokenizer is not provided."
     write_jsons = []
include: niah_1.yaml
task: ruler_fwe
download_dataset: !function fwe_utils.fwe_download
generation_kwargs:
  do_sample: false
  temperature: 0.0
  max_gen_toks: 50
  until: []
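# Illustrative note (not part of the config): the `!function` tag above resolves
# to `fwe_utils.fwe_download`, which synthesises the test split on the fly from
# the model's own tokenizer. Assuming the module sits next to utils.py as
# lm_eval.tasks.ruler.fwe_utils (and that the harness forwards model metadata in
# this shape), the hook can be exercised directly with a stand-in tokenizer:
#
#   from lm_eval.tasks.ruler.fwe_utils import fwe_download
#   splits = fwe_download(metadata={"tokenizer": "gpt2"})
#   print(splits["test"][0]["gen_prefix"])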
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools
import random
import string
from functools import cache
import datasets
import numpy as np
import transformers
from scipy.special import zeta
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
config = {
    "tokens_to_generate": 50,
    "template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
    "answer_prefix": """ Answer: According to the coded text above, the three most frequently appeared words are:""",
}
SEED = 42
TEMPLATE = "Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?\n\n"
def generate_input_output(
max_len: int,
tokenizer: "transformers.PreTrainedTokenizerFast",
num_words=-1,
coded_wordlen=6,
vocab_size=2000,
incremental=10,
alpha=2.0,
) -> tuple[str, list[str], int]:
# generate vocab
vocab = [
"".join(random.choices(string.ascii_lowercase, k=coded_wordlen))
for _ in range(vocab_size)
]
while len(set(vocab)) < vocab_size:
vocab.append("".join(random.choices(string.ascii_lowercase, k=coded_wordlen)))
vocab = sorted(list(set(vocab)))
random.Random(SEED).shuffle(vocab)
vocab[0] = "..." # treat the top ranked as noise
# sample words
template = TEMPLATE
def gen_text(num_words):
k = np.arange(1, len(vocab) + 1)
sampled_cnt = num_words * (k**-alpha) / zeta(alpha)
sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))]
sampled_words = [x for wlst in sampled_words for x in wlst]
random.Random(SEED).shuffle(sampled_words)
return template.format(context=" ".join(sampled_words), query=""), vocab[1:4]
if num_words > 0:
text, answer = gen_text(num_words)
while len(tokenizer(text).input_ids) > max_len:
num_words -= incremental
text, answer = gen_text(num_words)
else:
num_words = max_len // coded_wordlen # init
text, answer = gen_text(num_words)
while len(tokenizer(text).input_ids) < max_len:
num_words += incremental
text, answer = gen_text(num_words)
num_words -= incremental
text, answer = gen_text(num_words)
return text, answer, num_words
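# Worked example (illustrative, not executed): the sampler above draws a Zipfian
# count per vocab rank k, sampled_cnt = num_words * k**-alpha / zeta(alpha).
# With alpha=2.0 (zeta(2) = pi**2 / 6 ~ 1.645) and num_words=1000:
#
#   import numpy as np
#   from scipy.special import zeta
#   k = np.arange(1, 5)
#   print((1000 * k**-2.0 / zeta(2.0)).astype(int))   # -> [607 151  67  37]
#
# vocab[0] (the most frequent entry) is rewritten to "..." and treated as noise,
# so the three most frequent real coded words are always vocab[1:4], which is
# exactly what this function returns as the answer.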
def sys_kwext(
tokenizer: "transformers.PreTrainedTokenizerFast",
max_seq_length: int,
num_samples: int = 500,
vocab_size: int = -1,
coded_wordlen: int = 6,
alpha: float = 2.0,
tokens_to_generate: int = 50,
remove_newline_tab: bool = False,
) -> list[dict]:
write_jsons = []
vocab_size = max_seq_length // 50 if vocab_size == -1 else vocab_size
# get number of words
input_max_len = max_seq_length
_, _, num_example_words = generate_input_output(
input_max_len,
tokenizer,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
print("num_example_words:", num_example_words)
# Generate samples
for index in tqdm(range(num_samples)):
# construct input
input_max_len = max_seq_length
input_text, answer, _ = generate_input_output(
input_max_len,
tokenizer,
num_words=num_example_words,
coded_wordlen=coded_wordlen,
vocab_size=vocab_size,
incremental=input_max_len // 32,
alpha=alpha,
)
length = len(tokenizer(input_text).input_ids) + tokens_to_generate
if remove_newline_tab:
input_text = " ".join(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
formatted_output = {
"index": index,
"input": input_text,
"outputs": answer,
"length": length,
"max_seq_length": max_seq_length,
"gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
}
write_jsons.append(formatted_output)
return write_jsons
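# Each record emitted above is self-contained; an abridged, purely illustrative
# example (coded words and lengths are hypothetical):
#
#   {
#       "index": 0,
#       "input": "Read the following coded text ... qpfzkt mlwoah qpfzkt ...",
#       "outputs": ["qpfzkt", "mlwoah", "zcrevu"],
#       "length": 4063,
#       "max_seq_length": 4096,
#       "gen_prefix": "Answer: According to the coded text above, ...",
#   }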
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def get_dataset(pretrained, max_seq_length=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = sys_kwext(
tokenizer=tokenizer,
max_seq_length=max_seq_length,
)
return write_jsons
def fwe_download(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, max_seq_length=seq) for seq in SEQ_LENGTHS)
return {
"test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
)
}
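if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the harness never runs this block.
    # "gpt2" is just a stand-in tokenizer (fetched from the HF Hub).
    tok = get_tokenizer("gpt2")
    samples = sys_kwext(tokenizer=tok, max_seq_length=1024, num_samples=2)
    print(samples[0]["outputs"], samples[0]["length"])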
 include: niah_1.yaml
 task: ruler_qa
-dataset_path: rajpurkar/squad_v2
-dataset_name: main
-process_docs: !function qa_utils.process_squad_docs
-test_split: validation
-download_dataset: ""
+download_dataset: !function qa_utils.get_qa_dataset
+test_split: test
 generation_kwargs:
   do_sample: false
   temperature: 0.0
   max_gen_toks: 32
   until: []
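# Illustrative note (not part of the config): `qa_utils.get_qa_dataset` below
# downloads the SQuAD v2 dev split, drops unanswerable questions, and, for every
# length in SEQ_LENGTHS, surrounds each gold paragraph with distractor paragraphs
# until the prompt fills that length. Assuming the module sits next to utils.py
# as lm_eval.tasks.ruler.qa_utils, the hook can be exercised directly:
#
#   from lm_eval.tasks.ruler.qa_utils import get_qa_dataset
#   ds = get_qa_dataset(metadata={"tokenizer": "gpt2"})["test"]
#   print(ds[0]["gen_prefix"], ds[0]["outputs"])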
# import random
# from collections import defaultdict
#
# from datasets import tqdm
#
#
# # def process_qa_docs(df: datasets.Dataset) -> datasets.Dataset:
# # df = df.rename_column("question", "query")
# # df = df.rename_column("answers", "outputs")
# # df = df.map(lambda x: {"outputs": x["outputs"].get("text", [])})
# # return df
#
#
# def process_qa_docs(dataset):
# contexts = sorted(list(set(dataset["context"])))
# ctx_to_id = {c: i for i, c in enumerate(contexts)}
#
# title_contexts = defaultdict(list)
# for ex in dataset:
# title_contexts[ex["title"]].append(ctx_to_id[ex["context"]])
#
# # Process function for map
# def process_example(example):
# if len(example["answers"]["text"]) == 0:
# return None
#
# ctx_id = ctx_to_id[example["context"]]
# return {
# "query": example["question"],
# "outputs": example["answers"]["text"],
# "context": [ctx_id],
# "more_context": [
# i for i in title_contexts[example["title"]] if i != ctx_id
# ],
# }
#
# return dataset.map(
# process_example,
# remove_columns=dataset.column_names,
# filter_fn=lambda x: x is not None,
# )
#
#
# def generate_input_output(index, num_docs, dataset, random_seed=42):
# # Get current example
# example = dataset[index]
# curr_q = example["query"]
# curr_a = example["outputs"]
# curr_docs = example["context"]
# curr_more = example["more_context"]
# all_contexts = example["all_contexts"]
#
# if num_docs < len(all_contexts):
# if (num_docs - len(curr_docs)) > len(curr_more):
# # Need additional random docs
# addition_docs = [
# i for i in range(len(all_contexts)) if i not in curr_docs + curr_more
# ]
# all_docs = (
# curr_docs
# + curr_more
# + random.sample(
# addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))
# )
# )
# else:
# # Can just use some of the more_context
# all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))
#
# all_docs = [all_contexts[idx] for idx in all_docs]
# else:
# all_docs = all_contexts
#
# random.Random(random_seed).shuffle(all_docs)
#
# # Assuming DOCUMENT_PROMPT is something like "Document {i}: {document}"
# DOCUMENT_PROMPT = "Document {i}: {document}"
# context = "\n\n".join(
# [DOCUMENT_PROMPT.format(i=i + 1, document=d) for i, d in enumerate(all_docs)]
# )
#
# # Assuming template is something like "{context}\nQuestion: {query}"
# template = "{context}\nQuestion: {query}"
# input_text = template.format(context=context, query=curr_q)
#
# return input_text, curr_a
#
#
# def get_total_docs(example):
# return len(example["all_contexts"]) if "all_contexts" in example else 0
#
#
# def generate_samples(
# num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10
# ):
# write_jsons = []
# tokens_to_generate = tokens_to_generate
#
# # Find the perfect num_docs
# num_docs = incremental
#
# total_tokens = 0 # Track the total tokens generated for this example
# while total_tokens + tokens_to_generate < max_seq_length:
# input_text, answer = generate_input_output(0, num_docs)
# # Calculate the number of tokens in the example
# total_tokens = len(TOKENIZER(input_text + f" {answer}").input_ids)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
# )
# if total_tokens + tokens_to_generate > max_seq_length:
# num_docs -= incremental
# break
#
# num_docs += incremental
# if num_docs > len(DOCS):
# num_docs = len(DOCS)
# break
# print("Number of documents:", num_docs)
#
# # Generate samples
# for index in tqdm(range(num_samples)):
# used_docs = num_docs
# while True:
# try:
# input_text, answer = generate_input_output(
# index + pre_samples, used_docs
# )
# length = len(TOKENIZER(input_text).input_ids) + tokens_to_generate
# assert length <= max_seq_length, f"{length} exceeds max_seq_length."
# break
# except:
# if used_docs > incremental:
# used_docs -= incremental
#
# if remove_newline_tab:
# input_text = " ".join(
# input_text.replace("\n", " ").replace("\t", " ").strip().split()
# )
#
# formatted_output = {
# "index": index,
# "input": input_text,
# "outputs": answer,
# "length": length,
# }
# write_jsons.append(formatted_output)
#
# return write_jsons
import itertools # noqa: I001
import random
from functools import cache
import datasets
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
# config = {
# 'tokens_to_generate': 32,
# 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
# 'answer_prefix': """ Answer:""",
# }
SEED = 42
TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}"""
DOCUMENT_PROMPT = "Document {i}:\n{document}"
@cache
def download_json(url):
response = requests.get(url)
response.raise_for_status()
data = response.json()
return data
def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"):
data = download_json(url)
total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
total_docs = sorted(list(set(total_docs)))
total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}
total_qas = []
for d in data["data"]:
more_docs = [total_docs_dict[p["context"]] for p in d["paragraphs"]]
for p in d["paragraphs"]:
for qas in p["qas"]:
if not qas["is_impossible"]:
total_qas.append(
{
"query": qas["question"],
"outputs": [a["text"] for a in qas["answers"]],
"context": [total_docs_dict[p["context"]]],
"more_context": [
idx
for idx in more_docs
if idx != total_docs_dict[p["context"]]
],
}
)
return total_qas, total_docs
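# Shape of the returned data (values illustrative, abridged):
#
#   total_docs   : ["<paragraph text>", ...]        # deduplicated, sorted contexts
#   total_qas[i] : {
#       "query": "In what country is Normandy located?",
#       "outputs": ["France", "France", "France", "France"],
#       "context": [42],                            # index of the gold paragraph
#       "more_context": [40, 41, 43],               # other paragraphs of the same article
#   }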
def generate_input_output(
index: int, num_docs: int, qas: list[dict], docs: list[str]
) -> tuple[str, list[str]]:
curr_q: str = qas[index]["query"]
curr_a: list[str] = qas[index]["outputs"]
curr_docs: list[int] = qas[index]["context"]
curr_more: list[int] = qas[index].get("more_context", [])
if num_docs < len(docs):
if (num_docs - len(curr_docs)) > len(curr_more):
addition_docs = [
i for i, d in enumerate(docs) if i not in curr_docs + curr_more
]
all_docs = (
curr_docs
+ curr_more
+ random.sample(
addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))
)
)
else:
all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))
all_docs = [docs[idx] for idx in all_docs]
else:
all_docs = docs
random.Random(SEED).shuffle(all_docs)
context = "\n\n".join(
[DOCUMENT_PROMPT.format(i=i + 1, document=d) for i, d in enumerate(all_docs)]
)
input_text = TEMPLATE.format(context=context, query=curr_q)
return input_text, curr_a
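# Document selection above, illustrated: with num_docs=20, one gold paragraph and
# four same-article neighbours, the haystack is 1 gold + 4 neighbours + 15 random
# filler paragraphs; the final order is shuffled with the fixed SEED, while the
# filler sampling itself uses the unseeded global `random` module.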
def generate_samples(
tokenizer,
docs: list[str],
qas: list[dict],
max_seq_length: int,
num_samples: int = 500,
tokens_to_generate: int = 32,
pre_samples: int = 0,
incremental: int = 10,
remove_newline_tab=False,
) -> list[dict]:
write_jsons = []
# Find the perfect num_docs
num_docs = incremental
total_tokens = 0 # Track the total tokens generated for this example
while total_tokens + tokens_to_generate < max_seq_length:
input_text, answer = generate_input_output(0, num_docs, qas=qas, docs=docs)
# Calculate the number of tokens in the example
total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
print(
f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
)
if total_tokens + tokens_to_generate > max_seq_length:
num_docs -= incremental
break
num_docs += incremental
if num_docs > len(docs):
num_docs = len(docs)
break
print("Number of documents:", num_docs)
# Generate samples
for index in tqdm(range(num_samples)):
used_docs = num_docs
while True:
try:
input_text, answer = generate_input_output(
index + pre_samples, used_docs, qas=qas, docs=docs
)
length = len(tokenizer(input_text).input_ids) + tokens_to_generate
assert length <= max_seq_length, f"{length} exceeds max_seq_length."
break
except: # noqa: E722
if used_docs > incremental:
used_docs -= incremental
if remove_newline_tab:
input_text = " ".join(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
formatted_output = {
"index": index,
"input": input_text,
"outputs": answer,
"length": length,
"max_seq_length": max_seq_length,
"gen_prefix": "Answer:",
}
write_jsons.append(formatted_output)
return write_jsons
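# Sizing note (descriptive only): num_docs is grown in steps of `incremental`
# until the prompt for example 0 plus `tokens_to_generate` would exceed
# max_seq_length; the per-sample retry loop then shrinks `used_docs` whenever a
# particular question's documents tokenize longer than that budget.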
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = generate_samples(
tokenizer=tokenizer,
docs=docs,
qas=qas,
num_samples=500,
tokens_to_generate=128,
max_seq_length=max_seq_length,
)
return write_jsons
def get_qa_dataset(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
qas, docs = read_squad()
df = (
get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in SEQ_LENGTHS
)
return {
"test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
)
}
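if __name__ == "__main__":
    # Minimal smoke test, illustrative only: downloads the SQuAD v2 dev split and
    # the stand-in "gpt2" tokenizer; neither is mandated by the task config.
    qas, docs = read_squad()
    tok = get_tokenizer("gpt2")
    sample = generate_samples(
        tokenizer=tok, docs=docs, qas=qas, max_seq_length=4096, num_samples=1
    )[0]
    print(sample["length"], sample["outputs"])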