"docs/vscode:/vscode.git/clone" did not exist on "96f1b8ef751872cfe542e2a762f9b6fab7a69659"
Unverified Commit 80a10075 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Add loncxt tasks (#2629)

support for long-context (and other synthetic tasks)
* add ruler
* add longbench
* pass `metadata` to TaskConfig
parent f47ddaf8
import itertools
import logging
from typing import Generator
import datasets
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
eval_logger = logging.getLogger(__name__)
def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
    """Flatten per-sequence-length sample generators into a single test split."""
    samples = list(itertools.chain.from_iterable(df))
    test_split = datasets.Dataset.from_list(samples, split=datasets.Split.TEST)
    return {"test": test_split}
def niah_single_1(**kwargs):
    """Build NIAH single-needle task #1: repeated-sentence haystack, word keys, number values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work: build the tokenizer and haystack once instead
    # of once per sequence length inside the generator expression.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="repeat")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="repeat",
            type_needle_k="words",
            type_needle_v="numbers",
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_single_2(**kwargs):
    """Build NIAH single-needle task #2: essay haystack, word keys, number values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="essay")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_single_3(**kwargs):
    """Build NIAH single-needle task #3: essay haystack, word keys, UUID values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="essay")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="uuids",
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_multikey_1(**kwargs):
    """Build NIAH multi-key task #1: essay haystack, 4 word keys, number values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="essay")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_k=4,
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_multikey_2(**kwargs):
    """Build NIAH multi-key task #2: distractor-needle haystack, word keys, number values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="needle")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="needle",
            type_needle_k="words",
            type_needle_v="numbers",
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_multikey_3(**kwargs):
    """Build NIAH multi-key task #3: distractor-needle haystack, UUID keys and values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="needle")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="needle",
            type_needle_k="uuids",
            type_needle_v="uuids",
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_multivalue(**kwargs):
    """Build NIAH multi-value task: essay haystack, 4 number values per word key.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="essay")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_v=4,
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
def niah_multiquery(**kwargs):
    """Build NIAH multi-query task: essay haystack, 4 queried word keys, number values.

    kwargs: `max_seq_lengths` (context lengths, default DEFAULT_SEQ_LENGTHS);
    remaining kwargs are forwarded to `get_tokenizer`.
    """
    seq_lengths = kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    # Hoist loop-invariant work out of the per-sequence-length generator.
    tokenizer = get_tokenizer(**kwargs)
    haystack = get_haystack(type_haystack="essay")
    return download_dataset(
        generate_samples(
            haystack,
            max_seq_length=seq,
            template=TEMPLATE,
            type_haystack="essay",
            type_needle_k="words",
            type_needle_v="numbers",
            num_needle_q=4,
            num_samples=500,
            TOKENIZER=tokenizer,
        )
        for seq in seq_lengths
    )
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import random
import re
import uuid
from functools import lru_cache, cache
from typing import List, Union, Literal
import datasets
import numpy as np
from packaging.version import parse as parse_version
from importlib.metadata import version
from tqdm import tqdm
try:
import wonderwords
import nltk
from nltk import sent_tokenize
except ImportError:
raise ImportError(
'Please install the `wonderwords` and `nltk` packages to run this script. You can install them with `pip install lm_eval["ruler"]` or`pip install wonderwords nltk`.'
)
# ---- Task-generation constants ----
NUM_SAMPLES = 500  # default samples per (task, sequence length); not referenced in this chunk
REMOVE_NEWLINE_TAB = ""  # placeholder; whitespace stripping is opt-in via generate_samples
STOP_WORDS = ""  # placeholder; not referenced in this chunk
RANDOM_SEED = 42  # seed for the deterministic needle shuffle in generate_input_output
# Define Needle/Haystack Format
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
# Words: build "adjective-noun" compounds from wonderwords' word lists.
# NOTE(review): `_categories` is a private wonderwords attribute — may break on upgrade.
r = wonderwords.RandomWord()
nouns = r._categories["nouns"]
adjs = r._categories["adjectives"]
verbs = r._categories["verbs"]
words = [f"{adj}-{noun}" for adj in adjs for noun in nouns]
WORDS = sorted(list(set(words)))
# Positions: needle insertion depths as integer percentages 0..100 (40 steps).
DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int))
NLTK_MIN_VERSION = "3.9.1"  # older nltk releases are flagged below as vulnerable
RANK = os.environ.get("LOCAL_RANK", "0")  # only rank "0" downloads nltk data
@lru_cache(maxsize=1024)
def cached_sent_tokenize(text: str) -> List[str]:
    """Sentence-tokenize *text* with nltk, memoizing results for repeated inputs."""
    sentences = sent_tokenize(text)
    return sentences
def download_nltk_resources() -> None:
    """Verify the installed nltk version and download 'punkt_tab' if missing.

    Raises:
        RuntimeError: if the installed nltk is older than NLTK_MIN_VERSION
            (older releases are vulnerable to remote code execution).
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this security check.
    nltk_version = parse_version(version("nltk"))
    if nltk_version < parse_version(NLTK_MIN_VERSION):
        raise RuntimeError(
            f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
        )
    try:
        nltk.data.find("tokenizers/punkt_tab")
    except LookupError:
        # Only the local rank-0 process downloads, to avoid races in
        # multi-process launches.
        if RANK == "0":
            nltk.download("punkt_tab")
            print("Downloaded punkt_tab on rank 0")
# Ensure nltk tokenizer data is present at import time.
download_nltk_resources()
def generate_random_number(num_digits=7) -> str:
    """Return a random number with exactly *num_digits* digits, as a string."""
    low = 10 ** (num_digits - 1)
    high = (10**num_digits) - 1
    return str(random.randint(low, high))
def generate_random_word() -> str:
    """Pick one adjective-noun compound at random from the precomputed WORDS pool."""
    return random.choice(WORDS)
def generate_random_uuid() -> str:
    """Return a version-4 UUID string drawn from the seedable `random` module."""
    bits = random.getrandbits(128)
    return str(uuid.UUID(int=bits, version=4))
def generate_random(type_needle: str) -> str:
    """Dispatch to the matching random-value generator for *type_needle*."""
    generators = {
        "numbers": generate_random_number,
        "words": generate_random_word,
        "uuids": generate_random_uuid,
    }
    if type_needle not in generators:
        raise NotImplementedError(f"{type_needle} is not implemented.")
    return generators[type_needle]()
def generate_input_output(
    num_haystack: int,
    haystack: Union[list[str], str],
    *,
    type_haystack: str,
    num_needle_k: int,
    type_needle_k: str,
    num_needle_v: int,
    type_needle_v: str,
    template: str,
    num_needle_q: int = 1,
    random_seed: int = RANDOM_SEED,
) -> tuple[str, list[str], str]:
    """Build one needle-in-a-haystack prompt.

    Args:
        num_haystack: haystack size — number of essay words ("essay") or
            repeated sentences ("repeat"/"needle").
        haystack: word list for "essay"; a sentence / format string otherwise.
        type_haystack: one of "essay", "repeat", "needle".
        num_needle_k / type_needle_k: number and kind of needle keys.
        num_needle_v / type_needle_v: values per key and their kind
            ("numbers", "words", "uuids").
        template: prompt template with {type_needle_v}, {context}, {query} slots.
        num_needle_q: how many keys the query asks about.
        random_seed: seeds only the needle-order shuffle; other sampling uses
            the global `random` state.

    Returns:
        (input_text, answers, query): the prompt, the ground-truth values for
        the queried keys, and the rendered query string.
    """
    # Local NEEDLE shadows the module constant (same literal).
    NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
    keys, values, needles = [], [], []
    for _ in range(num_needle_k):
        keys.append(generate_random(type_needle_k))
        value = []
        # One needle sentence per (key, value) pair.
        for _ in range(num_needle_v):
            value.append(generate_random(type_needle_v))
            needles.append(
                NEEDLE.format(
                    type_needle_v=type_needle_v,
                    key=keys[-1],
                    value=value[-1],
                )
            )
        values.append(value)
    # Deterministic shuffle so needle insertion order is reproducible.
    random.Random(random_seed).shuffle(needles)
    # Context
    if type_haystack == "essay":
        assert isinstance(haystack, list)
        text = " ".join(haystack[:num_haystack])
        document_sents = cached_sent_tokenize(text.strip())
        # Sentence indices where needles get spliced in, at randomly sampled
        # depth percentages, padded with document start and end.
        insertion_positions = (
            [0]
            + sorted(
                [
                    int(len(document_sents) * (depth / 100))
                    for depth in random.sample(DEPTHS, len(needles))
                ]
            )
            + [len(document_sents)]
        )
        document_sents_list = []
        for i in range(1, len(insertion_positions)):
            last_pos = insertion_positions[i - 1]
            next_pos = insertion_positions[i]
            document_sents_list.append(" ".join(document_sents[last_pos:next_pos]))
            # Append one needle after each sentence chunk except the last.
            if i - 1 < len(needles):
                document_sents_list.append(needles[i - 1])
        context = " ".join(document_sents_list)
    else:
        # "repeat": the same filler sentence num_haystack times;
        # "needle": num_haystack distractor needles with random keys/values.
        if type_haystack == "repeat":
            sentences = [haystack] * num_haystack
        elif type_haystack == "needle":
            sentences = [
                haystack.format(
                    type_needle_v=type_needle_v,
                    key=generate_random(type_needle_k),
                    value=generate_random(type_needle_v),
                )
                for _ in range(num_haystack)
            ]
        # Insert the real needles at random positions (descending, so earlier
        # insertions don't shift later target indices).
        indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True)
        for index, element in zip(indexes, needles):
            sentences.insert(index, element)
        context = "\n".join(sentences)
    ## Query and Answer
    indices = random.sample(range(num_needle_k), num_needle_q)
    queries = [keys[i] for i in indices]
    answers = [a for i in indices for a in values[i]]
    query = (
        ", ".join(queries[:-1]) + ", and " + queries[-1]
        if len(queries) > 1
        else queries[0]
    )
    # No-op rebindings preserved from the original implementation.
    template = template
    type_needle_v = type_needle_v
    # Switch prompt to singular phrasing when exactly one answer is expected.
    if num_needle_q * num_needle_v == 1:
        template = template.replace("Some", "A")
        template = template.replace("are all", "is")
        template = template.replace("are", "is")
        template = template.replace("answers", "answer")
        type_needle_v = type_needle_v[:-1]  # remove "s"
    input_text = template.format(
        type_needle_v=type_needle_v,
        context=context,
        query=query,
    )
    return input_text, answers, query
def generate_samples(
    haystack,
    TOKENIZER=None,
    *,
    max_seq_length: int,
    type_haystack: str,
    type_needle_k: str,
    type_needle_v: str,
    template: str,
    num_samples: int = 500,
    tokens_to_generate: int = 128,
    num_needle_v: int = 1,
    num_needle_k: int = 1,
    num_needle_q=1,
    incremental: int = 500,
    remove_newline_tab: bool = False,
    random_seed: int = 42,
) -> list[dict]:
    """Generate `num_samples` NIAH examples fitting within `max_seq_length` tokens.

    First grows the haystack in `incremental` steps until a trial prompt would
    exceed the token budget, then emits samples at roughly that size, backing
    off per-sample when a prompt overshoots.

    Returns a list of dicts with keys: index, input, outputs, length,
    max_length, gen_prefix.
    """
    assert TOKENIZER is not None, "TOKENIZER is not defined."
    # Need at least as many keys as queried keys.
    num_needle_k = max(num_needle_k, num_needle_q)
    write_jsons = []
    tokens_to_generate = tokens_to_generate
    # Step size for growing/shrinking the haystack, by haystack type
    # (the `incremental` parameter is overridden here).
    if type_haystack == "essay":
        incremental = 500
    elif type_haystack == "repeat":
        incremental = 25
    elif type_haystack == "needle":
        incremental = 25
    if type_haystack != "essay" and max_seq_length < 4096:
        incremental = 5
    num_haystack = incremental
    total_tokens = 0  # Track the total tokens generated for the first example
    # Probe: grow the haystack until prompt + generation budget no longer fits.
    while total_tokens + tokens_to_generate < max_seq_length:
        input_text, answer, query = generate_input_output(
            num_haystack,
            haystack,
            type_haystack=type_haystack,
            num_needle_k=num_needle_k,
            type_needle_k=type_needle_k,
            num_needle_v=num_needle_v,
            type_needle_v=type_needle_v,
            template=template,
            num_needle_q=num_needle_q,
            random_seed=random_seed,
        )
        # Calculate the number of tokens in the example
        total_tokens = len(TOKENIZER(input_text + " ".join(answer)).input_ids)
        if total_tokens + tokens_to_generate > max_seq_length:
            num_haystack -= incremental
            break
        if type_haystack == "essay" and num_haystack > len(haystack):
            num_haystack = len(haystack)
            break
        num_haystack += incremental
    # print("Num haystack:", num_haystack)
    # Generate samples
    for index in tqdm(
        range(num_samples),
        desc=f"Generating synthetic samples: {type_haystack} | {max_seq_length}",
    ):
        used_haystack = num_haystack
        while True:
            try:
                input_text, answer, query = generate_input_output(
                    used_haystack,
                    haystack,
                    type_haystack=type_haystack,
                    num_needle_k=num_needle_k,
                    type_needle_k=type_needle_k,
                    num_needle_v=num_needle_v,
                    type_needle_v=type_needle_v,
                    template=template,
                    num_needle_q=num_needle_q,
                    random_seed=random_seed,
                )
                length = len(TOKENIZER(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            # ruff: noqa
            except:
                # Shrink and retry when the prompt is too long.
                # NOTE(review): if used_haystack <= incremental and generation
                # keeps failing, this loops forever — confirm inputs guarantee
                # eventual success.
                if used_haystack > incremental:
                    used_haystack -= incremental
        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            # Singular/plural generation prefix, matching the prompt phrasing.
            "gen_prefix": f"The special magic {type_needle_v[:-1]} for {query} mentioned in the provided text is"
            if num_needle_q * num_needle_v == 1
            else f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
        }
        # Sanity check: the first ground-truth value must appear in the prompt.
        if formatted_output["outputs"][0] not in formatted_output["input"]:
            assert False, (
                f"Needle not in input: {formatted_output}. Something went wrong."
            )
        write_jsons.append(formatted_output)
    return write_jsons
@cache
def get_haystack(
    type_haystack: Literal["essay", "repeat", "needle"],
) -> Union[list[str], str]:
    """Return the haystack material for a task type (cached per type).

    "essay": list of words from the Paul Graham essays corpus;
    "repeat": a fixed filler sentence; "needle": the needle format string.
    """
    NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
    if type_haystack == "needle":
        return NEEDLE
    if type_haystack == "repeat":
        return "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    if type_haystack == "essay":
        text = " ".join(
            datasets.load_dataset("baber/paul_graham_essays", split="train")["text"]
        )
        return re.sub(r"\s+", " ", text).split(" ")
    raise NotImplementedError(f"{type_haystack} is not implemented.")
include: qa_squad.yaml
task: ruler_qa_hotpot
custom_dataset: !function qa_utils.get_hotpotqa
include: niah_single_1.yaml
task: ruler_qa_squad
custom_dataset: !function qa_utils.get_squad
process_results: !function common_utils.process_results_part
test_split: test
generation_kwargs:
do_sample: false
temperature: 0.0
max_gen_toks: 32
until: []
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools # noqa: I001
import random
from functools import cache
import datasets
import requests
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
# QA task configuration: generation budget and the prompt template.
CONFIG = {
    "tokens_to_generate": 32,
    "template": """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
    "answer_prefix": """Answer:""",
}
SEED = 42  # fixed seed so document shuffling is reproducible
TEMPLATE = CONFIG["template"]
# Wrapper applied to each document when concatenating into the context.
DOCUMENT_PROMPT = "Document {i}:\n{document}"
@cache
def download_json(url) -> dict:
    """Fetch and parse a JSON document over HTTP, caching by URL.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the request stalls.
    """
    # An explicit timeout is required: requests.get has none by default and
    # would otherwise hang forever on a stalled connection.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    data = response.json()
    return data
@cache
def read_squad(
    url="https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
) -> tuple[list[dict], list[str]]:
    """Load SQuAD v2 dev and index its paragraphs.

    Returns (qas, docs): `docs` is the sorted list of unique paragraph texts;
    each qa dict holds the query, the gold answers, the index of its own
    paragraph ("context"), and the indices of the other paragraphs from the
    same article ("more_context"). Unanswerable questions are skipped.
    """
    data = download_json(url)
    docs = sorted({p["context"] for art in data["data"] for p in art["paragraphs"]})
    doc_index = {doc: i for i, doc in enumerate(docs)}
    qas = []
    for article in data["data"]:
        article_doc_ids = [doc_index[p["context"]] for p in article["paragraphs"]]
        for paragraph in article["paragraphs"]:
            own_id = doc_index[paragraph["context"]]
            for qa in paragraph["qas"]:
                if qa["is_impossible"]:
                    continue
                qas.append(
                    {
                        "query": qa["question"],
                        "outputs": [ans["text"] for ans in qa["answers"]],
                        "context": [own_id],
                        "more_context": [i for i in article_doc_ids if i != own_id],
                    }
                )
    return qas, docs
@cache
def read_hotpotqa(
    url="http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json",
) -> tuple[list[dict], list[str]]:
    """Load HotpotQA dev (distractor setting) and index its documents.

    Returns (qas, docs): `docs` is the sorted list of unique "title\\ntext"
    documents; each qa dict holds the query, the single gold answer, and the
    indices of its supporting/distractor documents.
    """
    data = download_json(url)
    docs = sorted(
        {f"{title}\n{''.join(paras)}" for item in data for title, paras in item["context"]}
    )
    doc_index = {doc: i for i, doc in enumerate(docs)}
    qas = [
        {
            "query": item["question"],
            "outputs": [item["answer"]],
            "context": [
                doc_index[f"{t}\n{''.join(p)}"] for t, p in item["context"]
            ],
        }
        for item in data
    ]
    return qas, docs
def generate_input_output(
    index: int, num_docs: int, qas: list[dict], docs: list[str]
) -> tuple[str, list[str]]:
    """Build one QA prompt for qas[index] padded out to `num_docs` documents.

    The question's own documents (and, for SQuAD, same-article documents) are
    included first; the remainder is filled with random other documents. The
    final document order is shuffled deterministically (SEED).

    Returns (input_text, gold_answers).
    """
    curr_q: str = qas[index]["query"]
    curr_a: list[str] = qas[index]["outputs"]
    curr_docs: list[int] = qas[index]["context"]
    curr_more: list[int] = qas[index].get("more_context", [])
    if num_docs < len(docs):
        if (num_docs - len(curr_docs)) > len(curr_more):
            # Set-based exclusion: the original used `i not in list + list`,
            # which is O(n) per membership test.
            excluded = set(curr_docs) | set(curr_more)
            addition_docs = [i for i in range(len(docs)) if i not in excluded]
            all_docs = (
                curr_docs
                + curr_more
                + random.sample(
                    addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))
                )
            )
        else:
            all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs))
        all_docs = [docs[idx] for idx in all_docs]
    else:
        # BUG FIX: the original aliased `docs` here, so the in-place shuffle
        # below mutated the caller's (cached) document list and corrupted the
        # doc-index mapping for all subsequent samples. Copy instead.
        all_docs = list(docs)
    random.Random(SEED).shuffle(all_docs)
    context = "\n\n".join(
        [DOCUMENT_PROMPT.format(i=i + 1, document=d) for i, d in enumerate(all_docs)]
    )
    input_text = TEMPLATE.format(context=context, query=curr_q)
    return input_text, curr_a
def generate_samples(
    tokenizer,
    docs: list[str],
    qas: list[dict],
    max_seq_length: int,
    num_samples: int = 500,
    tokens_to_generate: int = 32,
    pre_samples: int = 0,
    incremental: int = 10,
    remove_newline_tab=False,
) -> list[dict]:
    """Generate QA examples whose prompt + generation budget fit max_seq_length.

    First probes how many documents fit (growing by `incremental`), then emits
    `num_samples` examples, shrinking the per-sample document count when a
    prompt overshoots. `pre_samples` offsets the qa index.

    Returns a list of dicts with keys: index, input, outputs, length,
    max_length, gen_prefix.
    """
    write_jsons = []
    tokens_to_generate = tokens_to_generate
    # Find the perfect num_docs
    num_docs = incremental
    total_tokens = 0  # Track the total tokens generated for this example
    # Probe with the first question until the budget is exceeded.
    while total_tokens + tokens_to_generate < max_seq_length:
        input_text, answer = generate_input_output(0, num_docs, qas=qas, docs=docs)
        # Calculate the number of tokens in the example
        total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
        # print(
        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
        # )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_docs -= incremental
            break
        num_docs += incremental
        if num_docs > len(docs):
            num_docs = len(docs)
            break
    # print("Number of documents:", num_docs)
    # Generate samples
    for index in tqdm(
        range(num_samples), desc=f"Generating QA Samples | {max_seq_length}"
    ):
        used_docs = num_docs
        while True:
            try:
                input_text, answer = generate_input_output(
                    index + pre_samples, used_docs, qas=qas, docs=docs
                )
                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except:  # noqa: E722
                # Shrink and retry when the prompt is too long.
                # NOTE(review): if used_docs <= incremental and generation
                # keeps failing, this loops forever — confirm inputs guarantee
                # eventual success.
                if used_docs > incremental:
                    used_docs -= incremental
        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": "Answer:",
        }
        write_jsons.append(formatted_output)
    return write_jsons
def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs) -> list[dict]:
    """Produce QA samples sized for one target sequence length."""
    qa_tokenizer = get_tokenizer(pretrained)
    return generate_samples(
        tokenizer=qa_tokenizer,
        docs=docs,
        qas=qas,
        num_samples=500,
        tokens_to_generate=32,
        max_seq_length=max_seq_length,
    )
def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
    """Build the "test" split for a RULER QA task over all requested lengths.

    Args:
        ds: "squad" selects SQuAD; anything else selects HotpotQA.
        kwargs: `tokenizer`/`pretrained` select the tokenizer;
            `max_seq_lengths` overrides DEFAULT_SEQ_LENGTHS.
    """
    # Default to "" (not {}) for consistency with get_vt_dataset;
    # get_tokenizer takes a model name/path.
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", ""))
    if ds == "squad":
        qas, docs = read_squad()
    else:
        qas, docs = read_hotpotqa()
    df = (
        get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
        for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
def get_squad(**kwargs):
    """Build the RULER SQuAD QA test split (thin wrapper over get_qa_dataset)."""
    return get_qa_dataset("squad", **kwargs)
def get_hotpotqa(**kwargs):
    """Build the RULER HotpotQA test split (thin wrapper over get_qa_dataset)."""
    return get_qa_dataset("hotpotqa", **kwargs)
group: ruler
task:
- niah_single_1
- niah_single_2
- niah_single_3
- niah_multikey_1
- niah_multikey_2
- niah_multikey_3
- niah_multiquery
- niah_multivalue
- ruler_vt
- ruler_cwe
- ruler_fwe
- ruler_qa_squad
- ruler_qa_hotpot
aggregate_metric_list:
- metric: "4096"
weight_by_size: False
metadata:
version: 1
include: niah_single_1.yaml
task: ruler_vt
custom_dataset: !function vt_utils.get_vt_dataset
generation_kwargs:
do_sample: false
temperature: 0.0
max_gen_toks: 30
until: []
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# adapted from https://github.com/NVIDIA/RULER/blob/main/scripts/data/synthetic/variable_tracking.py
import itertools
import random
import string
from typing import TYPE_CHECKING, Union
import datasets
import numpy as np
from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Variable-tracking task configuration: generation budget, question template,
# and the answer prefix the model continues from.
CONFIG = {
    "variable_tracking": {
        "tokens_to_generate": 30,
        "template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
        "answer_prefix": """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: """,
    },
}
# Full prompt = question template + answer prefix.
# NOTE(review): the "assgined" typo is part of the runtime prompt string and is
# deliberately left untouched (changing it would change model inputs).
TEMPLATE = (
    CONFIG["variable_tracking"]["template"]
    + CONFIG["variable_tracking"]["answer_prefix"]
)
def generate_chains(
    num_chains: int, num_hops: int, is_icl: bool = False
) -> tuple[list[list[str]], list[list[str]]]:
    """Generate variable-assignment chains for the variable-tracking task.

    Each chain has `num_hops + 1` variables: the first is assigned a random
    5-digit number, each later one is assigned the previous variable. For
    in-context (few-shot) examples, 3-letter names and at most 10 hops are used.

    Returns:
        (vars_per_chain, statements_per_chain).
    """
    name_len = 5 if not is_icl else 3
    num_hops = num_hops if not is_icl else min(10, num_hops)
    # random.choices over ascii_uppercase is already uppercase; the original's
    # trailing .upper() (and the dead `vars_all = []` initializer) are removed.
    vars_all = [
        "".join(random.choices(string.ascii_uppercase, k=name_len))
        for _ in range((num_hops + 1) * num_chains)
    ]
    # Top up until enough distinct names exist.
    # NOTE(review): duplicates stay in the list, so extra appends can make
    # len(vars_all) exceed num_chains * (num_hops + 1); preserved as-is from
    # the original implementation.
    while len(set(vars_all)) < num_chains * (num_hops + 1):
        vars_all.append("".join(random.choices(string.ascii_uppercase, k=name_len)))
    vars_ret = []
    chains_ret = []
    for i in range(0, len(vars_all), num_hops + 1):
        this_vars = vars_all[i : i + num_hops + 1]
        vars_ret.append(this_vars)
        this_chain = [f"VAR {this_vars[0]} = {np.random.randint(10000, 99999)}"]
        for j in range(num_hops):
            # Trailing space kept to match the original statement format exactly.
            this_chain.append(f"VAR {this_vars[j + 1]} = VAR {this_vars[j]} ")
        chains_ret.append(this_chain)
    return vars_ret, chains_ret
def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
    """Build one variable-tracking prompt.

    Args:
        num_noises: number of repeated filler sentences in the context.
        num_chains / num_hops: chain layout passed to generate_chains.
        is_icl: build a (shorter) in-context example instead of a test sample.

    Returns:
        (input_text, variables_of_first_chain) — the latter are the answers.
    """
    vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)
    noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
    # Create a list of the repeated noise
    sentences = [noise] * num_noises
    # Too little noise for the chain statements: split into single sentences
    # to get more insertion slots.
    if len(sentences) <= len(chains[0]):
        sentences = [
            n + "." if len(n.strip()) > 0 else n
            for n in [x for noise in sentences for x in noise.split(".")]
        ]
        try:
            assert len(sentences) > len(chains[0]), (
                "Noises too short, unable to generate data"
            )
        except:  # noqa: E722
            # Fall back to truncating the chains instead of failing outright.
            print("reduces chain length for not enough noises")
            chains = [chain[: len(sentences) - 1] for chain in chains]
    # sample random positions to insert variable assignment
    for chain_i in chains:
        # sample random positions (sorted) to insert variable assignment
        positions = list(sorted(random.sample(range(len(sentences)), len(chain_i))))
        # Offset each insertion index by the number of statements already
        # inserted (j) so chain order is preserved in the context.
        for insert_pi, j in zip(positions, range(len(chain_i))):
            sentences.insert(insert_pi + j, chain_i[j])
    # Insert the passkey sentence at the random position
    context = " ".join(sentences)
    context = context.replace(". \n", ".\n")
    template = TEMPLATE
    # For ICL examples built against a modified (e.g. chat-templated) TEMPLATE,
    # strip the model-specific wrapper so the example is plain text.
    if (
        is_icl
        and template
        != CONFIG["variable_tracking"]["template"]
        + CONFIG["variable_tracking"]["answer_prefix"]
    ):
        # remove model template
        cutoff = template.index(CONFIG["variable_tracking"]["template"][:20])
        cutoff_ans = template.index(CONFIG["variable_tracking"]["answer_prefix"][:10])
        template = (
            " ".join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]
        )
    # The queried value is the number assigned to the first chain's first var.
    value = chains[0][0].split("=")[-1].strip()
    input_text = template.format(context=context, query=value, num_v=num_hops + 1)
    return input_text, vars[0]
def randomize_icl(icl_example: str) -> str:
    """Replace the answer variable names in a few-shot example with fresh random names."""
    marker = CONFIG["variable_tracking"]["answer_prefix"][-10:]
    answers_start = icl_example.index(marker)
    answer_tokens = icl_example[answers_start + 10 :].strip().split()
    for token in answer_tokens:
        replacement = "".join(
            random.choices(string.ascii_uppercase, k=len(token))
        ).upper()
        icl_example = icl_example.replace(token, replacement)
    return icl_example
def sys_vartrack_w_noise_random(
    tokenizer,
    num_samples: int,
    max_seq_length: int,
    incremental: int = 10,
    num_chains: int = 1,
    num_hops: int = 4,
    add_fewshot: bool = True,
    tokens_to_generate=30,
    icl_example: dict = None,
    remove_newline_tab=False,
):
    """Generate variable-tracking samples padded with noise up to max_seq_length.

    When `icl_example` is None the samples themselves are ICL-style (used to
    bootstrap the few-shot example); when provided, the example is prepended
    (with its answer variables randomized) to every generated sample.

    Returns a list of dicts with keys: index, input, outputs, length,
    max_length, gen_prefix.
    """
    write_jsons = []
    tokens_to_generate = tokens_to_generate
    # Find the perfect num_noises
    num_noises = incremental
    total_tokens = 0  # Track the total tokens generated for this example
    example_tokens = 0
    if add_fewshot and (icl_example is not None):
        # Flatten the few-shot example (prompt + answers) into a prefix string
        # and account for its token cost in the budget.
        icl_example_out = " ".join(icl_example["outputs"])
        icl_example = icl_example["input"] + " " + icl_example_out + "\n\n"
        example_tokens = len(tokenizer(icl_example).input_ids)
    # Probe: grow the noise count until the prompt would exceed the budget.
    while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
        input_text, answer = generate_input_output(
            num_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)
        )
        # Calculate the number of tokens in the example
        total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
        print(
            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
        )
        if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
            num_noises -= incremental
            break
        num_noises += incremental
    print("Num noises:", num_noises)
    # Generate samples
    for index in tqdm(range(num_samples)):
        used_noises = num_noises
        while True:
            try:
                input_text, answer = generate_input_output(
                    used_noises,
                    num_chains,
                    num_hops,
                    is_icl=add_fewshot & (icl_example is None),
                )
                length = (
                    len(tokenizer(input_text).input_ids)
                    + tokens_to_generate
                    + example_tokens
                )
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except:  # noqa: E722
                # Shrink the noise and retry when the prompt is too long.
                # NOTE(review): if used_noises <= incremental and generation
                # keeps failing, this loops forever — confirm inputs guarantee
                # eventual success.
                if used_noises > incremental:
                    used_noises -= incremental
        if add_fewshot and (icl_example is not None):
            # insert icl_example between model template and input
            cutoff = input_text.index(CONFIG["variable_tracking"]["template"][:20])
            input_text = (
                input_text[:cutoff]
                + randomize_icl(icl_example)
                + "\n\n"
                + input_text[cutoff:]
            )
        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )
        # The answer prefix is baked into the prompt; recover it so it can be
        # reported separately as gen_prefix.
        gen_prefix_index = input_text.rfind(
            " Answer: According to the chain(s) of variable assignment"
        )
        gen_prefix = input_text[gen_prefix_index:].strip()
        # This condition is to check if we are generating the few-shot.
        if icl_example is not None:
            input_text = input_text[:gen_prefix_index]
        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": gen_prefix.strip(),
        }
        write_jsons.append(formatted_output)
    return write_jsons
def get_dataset(
    tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
    seq=None,
    **kwargs,
) -> list[dict]:
    """Generate variable-tracking samples for one target length, with one few-shot example."""
    # Build a short few-shot example first, then the real samples with it attached.
    fewshot = sys_vartrack_w_noise_random(
        tokenizer=tokenizer,
        num_samples=1,
        max_seq_length=500,
        incremental=5,
    )[0]
    return sys_vartrack_w_noise_random(
        tokenizer=tokenizer,
        num_samples=500,
        max_seq_length=seq,
        icl_example=fewshot,
    )
def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
    """Build the variable-tracking "test" split over all requested lengths.

    kwargs: `tokenizer`/`pretrained` select the tokenizer; `max_seq_lengths`
    overrides DEFAULT_SEQ_LENGTHS.
    """
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", ""))
    # Hoist tokenizer construction out of the generator: the original rebuilt
    # it once per sequence length.
    tokenizer = get_tokenizer(pretrained)
    df = (
        get_dataset(tokenizer=tokenizer, seq=seq)
        for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
...@@ -11,7 +11,7 @@ import re ...@@ -11,7 +11,7 @@ import re
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, List, Tuple from typing import Any, Callable, Generator, List, Optional, Tuple
import numpy as np import numpy as np
import yaml import yaml
...@@ -28,6 +28,17 @@ HIGHER_IS_BETTER_SYMBOLS = { ...@@ -28,6 +28,17 @@ HIGHER_IS_BETTER_SYMBOLS = {
def setup_logging(verbosity=logging.INFO): def setup_logging(verbosity=logging.INFO):
# Configure the root logger # Configure the root logger
class CustomFormatter(logging.Formatter):
def format(self, record):
if record.name.startswith("lm_eval."):
record.name = record.name[len("lm_eval.") :]
return super().format(record)
formatter = CustomFormatter(
"%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
)
log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
level_map = { level_map = {
...@@ -39,12 +50,15 @@ def setup_logging(verbosity=logging.INFO): ...@@ -39,12 +50,15 @@ def setup_logging(verbosity=logging.INFO):
} }
log_level = level_map.get(str(log_level).upper(), logging.INFO) log_level = level_map.get(str(log_level).upper(), logging.INFO)
if not logging.root.handlers: if not logging.root.handlers:
logging.basicConfig( handler = logging.StreamHandler()
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(name)s:%(lineno)d] %(message)s", handler.setFormatter(formatter)
datefmt="%Y-%m-%d:%H:%M:%S",
level=log_level, root_logger = logging.getLogger()
) root_logger.addHandler(handler)
root_logger.setLevel(log_level)
if log_level == logging.DEBUG: if log_level == logging.DEBUG:
third_party_loggers = ["urllib3", "filelock", "fsspec"] third_party_loggers = ["urllib3", "filelock", "fsspec"]
for logger_name in third_party_loggers: for logger_name in third_party_loggers:
...@@ -114,12 +128,14 @@ def sanitize_list(sub): ...@@ -114,12 +128,14 @@ def sanitize_list(sub):
return str(sub) return str(sub)
def simple_parse_args_string(args_string): def simple_parse_args_string(args_string: Optional[str]) -> dict:
""" """
Parses something like Parses something like
args1=val1,arg2=val2 args1=val1,arg2=val2
Into a dictionary Into a dictionary
""" """
if args_string is None:
return {}
args_string = args_string.strip() args_string = args_string.strip()
if not args_string: if not args_string:
return {} return {}
...@@ -158,13 +174,13 @@ def pattern_match(patterns, source_list): ...@@ -158,13 +174,13 @@ def pattern_match(patterns, source_list):
return sorted(list(task_names)) return sorted(list(task_names))
def softmax(x): def softmax(x) -> np.ndarray:
"""Compute softmax values for each sets of scores in x.""" """Compute softmax values for each sets of scores in x."""
e_x = np.exp(x - np.max(x)) e_x = np.exp(x - np.max(x))
return e_x / e_x.sum() return e_x / e_x.sum()
def general_detokenize(string): def general_detokenize(string) -> str:
string = string.replace(" n't", "n't") string = string.replace(" n't", "n't")
string = string.replace(" )", ")") string = string.replace(" )", ")")
string = string.replace("( ", "(") string = string.replace("( ", "(")
......
...@@ -65,12 +65,14 @@ gptq = ["auto-gptq[triton]>=0.6.0"] ...@@ -65,12 +65,14 @@ gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"] hf_transfer = ["hf_transfer"]
ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"] ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"] ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
longbench = ["jieba", "fuzzywuzzy", "rouge"]
neuronx = ["optimum[neuronx]"] neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"] mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"] math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
optimum = ["optimum[openvino]"] optimum = ["optimum[openvino]"]
promptsource = ["promptsource>=0.2.3"] promptsource = ["promptsource>=0.2.3"]
ruler = ["nltk", "wonderwords", "scipy"]
sae_lens = ["sae_lens"] sae_lens = ["sae_lens"]
sentencepiece = ["sentencepiece>=0.1.98"] sentencepiece = ["sentencepiece>=0.1.98"]
sparsify = ["sparsify"] sparsify = ["sparsify"]
...@@ -89,11 +91,13 @@ all = [ ...@@ -89,11 +91,13 @@ all = [
"lm_eval[hf_transfer]", "lm_eval[hf_transfer]",
"lm_eval[ibm_watsonx_ai]", "lm_eval[ibm_watsonx_ai]",
"lm_eval[ifeval]", "lm_eval[ifeval]",
"lm_eval[longbench]",
"lm_eval[mamba]", "lm_eval[mamba]",
"lm_eval[math]", "lm_eval[math]",
"lm_eval[multilingual]", "lm_eval[multilingual]",
"lm_eval[openai]", "lm_eval[openai]",
"lm_eval[promptsource]", "lm_eval[promptsource]",
"lm_eval[ruler]",
"lm_eval[sae_lens]", "lm_eval[sae_lens]",
"lm_eval[sentencepiece]", "lm_eval[sentencepiece]",
"lm_eval[sparsify]", "lm_eval[sparsify]",
......
...@@ -44,3 +44,5 @@ If other tasks on this dataset are already supported: ...@@ -44,3 +44,5 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted? * [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment