Commit ca9316c9 authored by Baber

add other ruler tasks

parent 92953afa
...
@@ -83,7 +83,7 @@ def sys_word_pair_random(
     tokenizer=None,
     incremental: int = 10,
     remove_newline_tab=False,
-    tokens_to_generate=30,
+    tokens_to_generate=120,
 ):
     assert tokenizer is not None, "Tokenizer is not provided."
     write_jsons = []
...
include: niah_1.yaml
task: ruler_fwe
download_dataset: !function fwe_utils.fwe_download
generation_kwargs:
  do_sample: false
  temperature: 0.0
  max_gen_toks: 50
  until: []
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools
import random
import string
from functools import cache

import datasets
import numpy as np
import transformers
from scipy.special import zeta
from tqdm import tqdm
from transformers import AutoTokenizer

from lm_eval.tasks.ruler.utils import SEQ_LENGTHS


config = (
    {
        "tokens_to_generate": 50,
        "template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
        "answer_prefix": """ Answer: According to the coded text above, the three most frequently appeared words are:""",
    },
)
SEED = 42
TEMPLATE = "Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?\n\n"


def generate_input_output(
    max_len: int,
    tokenizer: "transformers.PreTrainedTokenizerFast",
    num_words=-1,
    coded_wordlen=6,
    vocab_size=2000,
    incremental=10,
    alpha=2.0,
) -> tuple[str, list[str], int]:
    # generate vocab
    vocab = [
        "".join(random.choices(string.ascii_lowercase, k=coded_wordlen))
        for _ in range(vocab_size)
    ]
    while len(set(vocab)) < vocab_size:
        vocab.append("".join(random.choices(string.ascii_lowercase, k=coded_wordlen)))
    vocab = sorted(list(set(vocab)))
    random.Random(SEED).shuffle(vocab)
    vocab[0] = "..."  # treat the top ranked as noise

    # sample words
    template = TEMPLATE

    def gen_text(num_words):
        k = np.arange(1, len(vocab) + 1)
        sampled_cnt = num_words * (k**-alpha) / zeta(alpha)
        sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))]
        sampled_words = [x for wlst in sampled_words for x in wlst]
        random.Random(SEED).shuffle(sampled_words)
        return template.format(context=" ".join(sampled_words), query=""), vocab[1:4]

    if num_words > 0:
        num_words = num_words
        text, answer = gen_text(num_words)
        while len(tokenizer(text).input_ids) > max_len:
            num_words -= incremental
            text, answer = gen_text(num_words)
    else:
        num_words = max_len // coded_wordlen  # init
        text, answer = gen_text(num_words)
        while len(tokenizer(text).input_ids) < max_len:
            num_words += incremental
            text, answer = gen_text(num_words)
        num_words -= incremental

    text, answer = gen_text(num_words)
    return text, answer, num_words


def sys_kwext(
    tokenizer: "transformers.PreTrainedTokenizerFast",
    max_seq_length: int,
    num_samples: int = 500,
    vocab_size: int = -1,
    coded_wordlen: int = 6,
    alpha: float = 2.0,
    tokens_to_generate: int = 50,
    remove_newline_tab: bool = False,
) -> list[dict]:
    write_jsons = []
    tokens_to_generate = tokens_to_generate
    vocab_size = max_seq_length // 50 if vocab_size == -1 else vocab_size

    # get number of words
    input_max_len = max_seq_length
    _, _, num_example_words = generate_input_output(
        input_max_len,
        tokenizer,
        coded_wordlen=coded_wordlen,
        vocab_size=vocab_size,
        incremental=input_max_len // 32,
        alpha=alpha,
    )
    print("num_example_words:", num_example_words)

    # Generate samples
    for index in tqdm(range(num_samples)):
        # construct input
        input_max_len = max_seq_length
        input_text, answer, _ = generate_input_output(
            input_max_len,
            tokenizer,
            num_words=num_example_words,
            coded_wordlen=coded_wordlen,
            vocab_size=vocab_size,
            incremental=input_max_len // 32,
            alpha=alpha,
        )

        length = len(tokenizer(input_text).input_ids) + tokens_to_generate

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_seq_length": max_seq_length,
            "gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
        }
        write_jsons.append(formatted_output)

    return write_jsons


@cache
def get_tokenizer(pretrained):
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


def get_dataset(pretrained, max_seq_length=None, **kwargs):
    tokenizer = get_tokenizer(pretrained)
    write_jsons = sys_kwext(
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
    )
    return write_jsons


def fwe_download(**kwargs):
    kwargs = kwargs.get("metadata", {})
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    df = (get_dataset(pretrained, max_seq_length=seq) for seq in SEQ_LENGTHS)
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
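The coded-word counts produced by gen_text above follow a Zipf law: the k-th vocabulary word is repeated roughly num_words * k**-alpha / zeta(alpha) times, so the answer is always the words at ranks 2-4 (rank 1 is overwritten with the "..." noise token). A minimal standalone sketch of that count profile, independent of the harness:

# Sketch only: reproduces the Zipf-style frequency profile used by gen_text().
import numpy as np
from scipy.special import zeta

num_words, alpha = 1_000, 2.0
k = np.arange(1, 11)  # ranks 1..10
counts = (num_words * k**-alpha / zeta(alpha)).astype(int)
print(counts)  # roughly [607 151 67 37 ...]; rank 1 becomes the "..." noise token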
 include: niah_1.yaml
 task: ruler_qa
-dataset_path: rajpurkar/squad_v2
-dataset_name: main
-process_docs: !function qa_utils.process_squad_docs
-test_split: validation
-download_dataset: ""
+download_dataset: !function qa_utils.get_qa_dataset
+test_split: test
+generation_kwargs:
+  do_sample: false
+  temperature: 0.0
+  max_gen_toks: 32
+  until: []
(The old contents of this file, a fully commented-out draft of these QA utilities, are deleted in this commit and replaced by the implementation below.)

import itertools  # noqa: I001
import random
from functools import cache

import datasets
import requests
from tqdm import tqdm
from transformers import AutoTokenizer

from lm_eval.tasks.ruler.utils import SEQ_LENGTHS


# config = {
#     'tokens_to_generate': 32,
#     'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
#     'answer_prefix': """ Answer:""",
# }
SEED = 42
TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}"""
DOCUMENT_PROMPT = "Document {i}:\n{document}"


@cache
def download_json(url):
    response = requests.get(url)
    response.raise_for_status()
    data = response.json()
    return data


def read_squad(url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"):
    data = download_json(url)
    total_docs = [p["context"] for d in data["data"] for p in d["paragraphs"]]
    total_docs = sorted(list(set(total_docs)))
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data["data"]:
        more_docs = [total_docs_dict[p["context"]] for p in d["paragraphs"]]
        for p in d["paragraphs"]:
            for qas in p["qas"]:
                if not qas["is_impossible"]:
                    total_qas.append(
                        {
                            "query": qas["question"],
                            "outputs": [a["text"] for a in qas["answers"]],
                            "context": [total_docs_dict[p["context"]]],
                            "more_context": [
                                idx
                                for idx in more_docs
                                if idx != total_docs_dict[p["context"]]
                            ],
                        }
                    )

    return total_qas, total_docs


def generate_input_output(
    index: int, num_docs: int, qas: list[dict], docs: list[str]
) -> tuple[str, list[str]]:
    curr_q: str = qas[index]["query"]
    curr_a: list[str] = qas[index]["outputs"]
    curr_docs: list[int] = qas[index]["context"]
    curr_more: list[int] = qas[index].get("more_context", [])
    if num_docs < len(docs):
        if (num_docs - len(curr_docs)) > len(curr_more):
            addition_docs = [
                i for i, d in enumerate(docs) if i not in curr_docs + curr_more
            ]
            all_docs = (
                curr_docs
                + curr_more
                + random.sample(
                    addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))
                )
            )
        else:
            all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))

        all_docs = [docs[idx] for idx in all_docs]
    else:
        all_docs = docs

    random.Random(SEED).shuffle(all_docs)

    context = "\n\n".join(
        [DOCUMENT_PROMPT.format(i=i + 1, document=d) for i, d in enumerate(all_docs)]
    )
    input_text = TEMPLATE.format(context=context, query=curr_q)
    return input_text, curr_a


def generate_samples(
    tokenizer,
    docs: list[str],
    qas: list[dict],
    max_seq_length: int,
    num_samples: int = 500,
    tokens_to_generate: int = 32,
    pre_samples: int = 0,
    incremental: int = 10,
    remove_newline_tab=False,
) -> list[dict]:
    write_jsons = []
    tokens_to_generate = tokens_to_generate

    # Find the perfect num_docs
    num_docs = incremental

    total_tokens = 0  # Track the total tokens generated for this example
    while total_tokens + tokens_to_generate < max_seq_length:
        input_text, answer = generate_input_output(0, num_docs, qas=qas, docs=docs)
        # Calculate the number of tokens in the example
        total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
        print(
            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
        )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_docs -= incremental
            break

        num_docs += incremental
        if num_docs > len(docs):
            num_docs = len(docs)
            break
    print("Number of documents:", num_docs)

    # Generate samples
    for index in tqdm(range(num_samples)):
        used_docs = num_docs
        while True:
            try:
                input_text, answer = generate_input_output(
                    index + pre_samples, used_docs, qas=qas, docs=docs
                )
                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except:  # noqa: E722
                if used_docs > incremental:
                    used_docs -= incremental

        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_seq_length": max_seq_length,
            "gen_prefix": "Answer:",
        }
        write_jsons.append(formatted_output)

    return write_jsons


@cache
def get_tokenizer(pretrained):
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
    tokenizer = get_tokenizer(pretrained)
    write_jsons = generate_samples(
        tokenizer=tokenizer,
        docs=docs,
        qas=qas,
        num_samples=500,
        tokens_to_generate=128,
        max_seq_length=max_seq_length,
    )
    return write_jsons


def get_qa_dataset(**kwargs):
    kwargs = kwargs.get("metadata", {})
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    qas, docs = read_squad()
    df = (
        get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
        for seq in SEQ_LENGTHS
    )
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
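For orientation, a minimal sketch of how the two `!function` download hooks wired up in the YAML files above could be exercised directly. The module paths and the `gpt2` tokenizer name are illustrative assumptions, not part of this commit:

# Sketch only: builds the RULER FWE and QA test splits outside the harness.
# Assumed module paths; adjust to wherever fwe_utils.py / qa_utils.py live.
from lm_eval.tasks.ruler import fwe_utils, qa_utils

metadata = {"tokenizer": "gpt2"}  # illustrative tokenizer choice

fwe_splits = fwe_utils.fwe_download(metadata=metadata)
qa_splits = qa_utils.get_qa_dataset(metadata=metadata)

sample = qa_splits["test"][0]
print(sample["length"], sample["gen_prefix"], sample["outputs"])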