Commit 3bbc73d4 authored by Baber

cleanup

parent 2c8a1656
import re
from functools import cache
from typing import TYPE_CHECKING, Union
from transformers import AutoTokenizer
if TYPE_CHECKING:
import transformers
SEQ_LENGTHS = (
# 131072,
# 65536,
# 32768,
# 16384,
# 8192,
4096,
)
@cache
def get_tokenizer(
tokenizer=None, pretrained=None, **kwargs
) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
pretrained = tokenizer or pretrained
assert pretrained, "No tokenizer or pretrained provided."
print("using tokenizer ", pretrained)
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def postprocess_pred(predict_str: str) -> str:
predict_str = predict_str.strip()
# Remove all non-printable characters
np_pattern = re.compile(r"[\x00-\x1f]")
predict_str = np_pattern.sub("\n", predict_str).strip()
return predict_str
def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
score = sum(
[
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
) / len(preds)
return score
def string_match_part(preds: list[str], refs: list[list[str]]) -> float:
    # Partial match: a sample scores 1.0 if any single reference appears in the prediction.
    score = sum(
        [
            max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
            for pred, ref in zip(preds, refs)
        ]
    ) / len(preds)
    return score
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results[0])
score = string_match_all([pred], [doc["outputs"]])
metrics[str(input_len)] = score
return metrics
def process_results_part(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results[0])
score = string_match_part([pred], [doc["outputs"]])
metrics[str(input_len)] = score
return metrics
def aggregate_metrics(metrics: list[float]) -> float:
res = [x for x in metrics if x != -1]
if not res:
# we don't have any samples with this length
return 0.0
return sum(res) / len(res)
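For reference, a minimal usage sketch of the helpers above (not part of the committed file; the sample doc and prediction below are hand-made stand-ins):

from lm_eval.tasks.ruler.common_utils import (
    aggregate_metrics,
    process_results,
    string_match_all,
)

doc = {"max_length": 4096, "outputs": ["eiffel tower"]}
print(process_results(doc, ["  The answer is Eiffel Tower.\x07  "]))
# -> {"4096": 1.0}; any other length configured in SEQ_LENGTHS stays at -1.0

# string_match_all averages over all references, so a prediction that contains
# only some of them earns partial credit:
print(string_match_all(["a b"], [["a", "c"]]))  # 0.5

# At aggregation time the -1.0 placeholders are dropped before averaging:
print(aggregate_metrics([1.0, 0.0, -1.0, -1.0]))  # 0.5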
......@@ -13,14 +13,12 @@
# limitations under the License
import itertools
import random
from functools import cache
import datasets
import wonderwords
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
RNG = random.Random(42)
......@@ -107,9 +105,9 @@ def sys_word_pair_random(
+ " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
).input_ids
)
print(
f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
# )
if total_tokens + tokens_to_generate > max_seq_length:
num_words -= incremental
break
......@@ -119,10 +117,12 @@ def sys_word_pair_random(
num_words = len(WORDS)
break
print("num_words:", num_words)
# print("num_words:", num_words)
# Generate samples
for index in tqdm(range(num_samples)):
for index in tqdm(
range(num_samples), desc=f"Generating CWE Samples | {max_seq_length}"
):
used_words = num_words
while True:
try:
......@@ -161,11 +161,6 @@ def sys_word_pair_random(
return write_jsons
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def get_dataset(pretrained, seq=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = sys_word_pair_random(
......
......@@ -23,12 +23,14 @@ from bs4 import BeautifulSoup
from tqdm.asyncio import tqdm as async_tqdm
@cache
async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
response = await client.get(url)
response.raise_for_status()
return response.text
@cache
async def process_html_essay(
client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
) -> None:
......@@ -50,6 +52,7 @@ async def process_html_essay(
print(f"Failed to download {filename}: {str(e)}")
@cache
async def process_text_essay(
client: httpx.AsyncClient, url: str, temp_folder: str
) -> None:
......@@ -64,6 +67,7 @@ async def process_text_essay(
print(f"Failed to download {filename}: {str(e)}")
@cache
async def get_essays() -> Dict[str, str]:
temp_folder_repo = "essay_repo"
temp_folder_html = "essay_html"
......
......@@ -14,16 +14,14 @@
import itertools
import random
import string
from functools import cache
import datasets
import numpy as np
import transformers
from scipy.special import zeta
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
config = (
......@@ -112,9 +110,11 @@ def sys_kwext(
incremental=input_max_len // 32,
alpha=alpha,
)
print("num_example_words:", num_example_words)
# print("num_example_words:", num_example_words)
# Generate samples
for index in tqdm(range(num_samples)):
for index in tqdm(
range(num_samples), desc=f"Generating FWE Samples | {max_seq_length}"
):
# construct input
input_max_len = max_seq_length
input_text, answer, _ = generate_input_output(
......@@ -139,7 +139,7 @@ def sys_kwext(
"input": input_text,
"outputs": answer,
"length": length,
"max_seq_length": max_seq_length,
"max_length": max_seq_length,
"gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
}
write_jsons.append(formatted_output)
......@@ -147,11 +147,6 @@ def sys_kwext(
return write_jsons
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def get_dataset(pretrained, max_seq_length=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = sys_kwext(
......
task: niah_multikey_1
include: niah_single_1.yaml
download_dataset: !function utils.niah_multikey_1
download_dataset: !function niah_utils.niah_multikey_1
task: niah_multikey_2
include: niah_single_1.yaml
download_dataset: !function utils.niah_multikey_2
download_dataset: !function niah_utils.niah_multikey_2
task: niah_multikey_3
include: niah_single_1.yaml
download_dataset: !function utils.niah_multikey_3
download_dataset: !function niah_utils.niah_multikey_3
task: niah_multiquery
include: niah_single_1.yaml
download_dataset: !function utils.niah_multiquery
download_dataset: !function niah_utils.niah_multiquery
task: niah_multivalue
include: niah_single_1.yaml
download_dataset: !function utils.niah_multivalue
download_dataset: !function niah_utils.niah_multivalue
......@@ -5,20 +5,20 @@ dataset_path: ""
dataset_name: ""
output_type: generate_until
test_split: test
download_dataset: !function utils.niah_single_1
download_dataset: !function niah_utils.niah_single_1
doc_to_text: "{{input}}"
doc_to_target: "{{outputs[0]}}"
gen_prefix: "{{gen_prefix}}"
process_results: !function utils.process_results
process_results: !function common_utils.process_results
metric_list:
- metric: "4096"
aggregation: !function utils.aggregate_metrics
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "8192"
aggregation: !function utils.aggregate_metrics
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
- metric: "16384"
aggregation: !function utils.aggregate_metrics
aggregation: !function common_utils.aggregate_metrics
higher_is_better: true
# - metric: "32768"
# aggregation: !function utils.aggregate_metrics
......
task: niah_single_2
include: niah_single_1.yaml
download_dataset: !function utils.niah_single_2
download_dataset: !function niah_utils.niah_single_2
task: niah_single_3
include: niah_single_1.yaml
download_dataset: !function utils.niah_single_3
download_dataset: !function niah_utils.niah_single_3
# noqa
import itertools
import re
from functools import cache
from typing import Literal, Generator, Union, TYPE_CHECKING
from typing import Literal, Union, Generator
import datasets
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.essays import get_all_essays
from lm_eval.tasks.ruler.prepare import generate_samples
if TYPE_CHECKING:
import transformers
from lm_eval.tasks.ruler.prepare import generate_samples, get_haystack
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
def get_tokenizer(
**kwargs,
) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
assert pretrained, "No tokenizer or pretrained provided."
print("using tokenizer ", pretrained)
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
# TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are"""
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
SEQ_LENGTHS = (
# 131072,
# 65536,
# 32768,
# 16384,
# 8192,
4096,
)
NUM_SAMPLES = 500
REMOVE_NEWLINE_TAB = ""
STOP_WORDS = ""
RANDOM_SEED = 42
@cache
def get_haystack(
type_haystack: Literal["essay", "repeat", "needle"],
) -> Union[list[str], str]:
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
if type_haystack == "essay":
essay = get_all_essays()["text"]
# essay = json.load(open(essay))["text"]
haystack = re.sub(r"\s+", " ", essay).split(" ")
elif type_haystack == "repeat":
haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
elif type_haystack == "needle":
haystack = NEEDLE
else:
raise NotImplementedError(f"{type_haystack} is not implemented.")
return haystack
def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
return {
"test": datasets.Dataset.from_list(
......@@ -78,7 +26,7 @@ niah_single_1 = lambda **kwargs: download_dataset(
type_haystack="repeat",
type_needle_k="words",
type_needle_v="numbers",
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -91,7 +39,7 @@ niah_single_2 = lambda **kwargs: download_dataset(
type_haystack="essay",
type_needle_k="words",
type_needle_v="numbers",
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -104,7 +52,7 @@ niah_single_3 = lambda **kwargs: download_dataset(
type_haystack="essay",
type_needle_k="words",
type_needle_v="uuids",
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -118,7 +66,7 @@ niah_multikey_1 = lambda **kwargs: download_dataset(
type_needle_k="words",
type_needle_v="numbers",
num_needle_k=4,
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -131,7 +79,7 @@ niah_multikey_2 = lambda **kwargs: download_dataset(
type_haystack="needle",
type_needle_k="words",
type_needle_v="numbers",
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -144,7 +92,7 @@ niah_multikey_3 = lambda **kwargs: download_dataset(
type_haystack="needle",
type_needle_k="uuids",
type_needle_v="uuids",
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -158,7 +106,7 @@ niah_multivalue = lambda **kwargs: download_dataset(
type_needle_k="words",
type_needle_v="numbers",
num_needle_v=4,
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
......@@ -172,45 +120,7 @@ niah_multiquery = lambda **kwargs: download_dataset(
type_needle_k="words",
type_needle_v="numbers",
num_needle_q=4,
TOKENIZER=get_tokenizer(**kwargs),
TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
)
for seq in SEQ_LENGTHS
)
def postprocess_pred(predict_str: str) -> str:
predict_str = predict_str.strip()
# Remove all non-printable characters
np_pattern = re.compile(r"[\x00-\x1f]")
predict_str = np_pattern.sub("\n", predict_str).strip()
return predict_str
def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
score = sum(
[
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
) / len(preds)
return score
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
input_len = doc["max_length"]
pred = postprocess_pred(results[0])
score = string_match_all([pred], [doc["outputs"]])
metrics[str(input_len)] = score
return metrics
def aggregate_metrics(metrics: list[float]) -> float:
res = [x for x in metrics if x != -1]
if not res:
# we don't have any samples with this length
return 0.0
return sum(res) / len(res)
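As a hedged illustration (not part of the diff), the task factories above can be driven directly. The dotted import path and the "gpt2" tokenizer name are assumptions; the harness resolves niah_utils through the !function references in the YAML files and forwards the task metadata as keyword arguments:

from lm_eval.tasks.ruler import niah_utils  # import path assumed from the repo layout

# metadata mirrors what the harness forwards; get_tokenizer() only needs
# a "tokenizer" or "pretrained" entry.
splits = niah_utils.niah_single_1(metadata={"tokenizer": "gpt2"})
test = splits["test"]  # a datasets.Dataset built by download_dataset()
row = test[0]
# Expect the fields consumed by the YAML templates and process_results:
# "input", "outputs", "gen_prefix", "max_length"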
......@@ -15,10 +15,10 @@
import os
import random
import re
import uuid
from linecache import cache
from functools import lru_cache
from typing import List, Union
from functools import lru_cache, cache
from typing import List, Union, Literal
import numpy as np
import wonderwords
......@@ -29,6 +29,7 @@ from importlib.metadata import version
from tqdm import tqdm
from lm_eval.tasks.ruler.essays import get_all_essays
NUM_SAMPLES = 500
REMOVE_NEWLINE_TAB = ""
......@@ -317,3 +318,21 @@ def generate_samples(
)
write_jsons.append(formatted_output)
return write_jsons
@cache
def get_haystack(
type_haystack: Literal["essay", "repeat", "needle"],
) -> Union[list[str], str]:
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
if type_haystack == "essay":
essay = get_all_essays()["text"]
# essay = json.load(open(essay))["text"]
haystack = re.sub(r"\s+", " ", essay).split(" ")
elif type_haystack == "repeat":
haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
elif type_haystack == "needle":
haystack = NEEDLE
else:
raise NotImplementedError(f"{type_haystack} is not implemented.")
return haystack
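Illustrative only (not part of the commit): the relocated, cached helper returns a different shape per haystack type, so callers handle both.

from lm_eval.tasks.ruler.prepare import get_haystack

words = get_haystack("essay")        # list[str]: downloaded essay text, whitespace-split
filler = get_haystack("repeat")      # str: the repeated filler sentences
needle_tpl = get_haystack("needle")  # str: the needle template string
# any other value raises NotImplementedError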
include: niah_single_1.yaml
task: ruler_qa_squad
download_dataset: !function qa_utils.get_squad
process_results: !function common_utils.process_results_part
test_split: test
generation_kwargs:
do_sample: false
......
......@@ -20,9 +20,8 @@ from functools import cache
import datasets
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
config = {
"tokens_to_generate": 32,
......@@ -154,9 +153,9 @@ def generate_samples(
input_text, answer = generate_input_output(0, num_docs, qas=qas, docs=docs)
# Calculate the number of tokens in the example
total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
print(
f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
# )
if total_tokens + tokens_to_generate > max_seq_length:
num_docs -= incremental
break
......@@ -165,10 +164,12 @@ def generate_samples(
if num_docs > len(docs):
num_docs = len(docs)
break
print("Number of documents:", num_docs)
# print("Number of documents:", num_docs)
# Generate samples
for index in tqdm(range(num_samples)):
for index in tqdm(
range(num_samples), desc=f"Generating QA Samples | {max_seq_length}"
):
used_docs = num_docs
while True:
try:
......@@ -192,7 +193,7 @@ def generate_samples(
"input": input_text,
"outputs": answer,
"length": length,
"max_seq_length": max_seq_length,
"max_length": max_seq_length,
"gen_prefix": "Answer:",
}
write_jsons.append(formatted_output)
......@@ -200,11 +201,6 @@ def generate_samples(
return write_jsons
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
tokenizer = get_tokenizer(pretrained)
write_jsons = generate_samples(
......@@ -226,7 +222,7 @@ def get_qa_dataset(ds, **kwargs):
else:
qas, docs = read_hotpotqa()
df = (
get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
for seq in SEQ_LENGTHS
)
......@@ -243,7 +239,3 @@ def get_squad(**kwargs):
def get_hotpotqa(**kwargs):
return get_qa_dataset("hotpotqa", **kwargs)
# get_squad = lambda **kwargs: partial(get_qa_dataset, "squad")(**kwargs)
# get_hotpotqa = lambda **kwargs: partial(get_qa_dataset, "hotpotqa")(**kwargs)
group: ruler
task:
- niah_single_1
- niah_single_2
- niah_single_3
- niah_multikey_1
- niah_multikey_2
- niah_multikey_3
- niah_multiquery
- niah_multivalue
# - niah_single_1
# - niah_single_2
# - niah_single_3
# - niah_multikey_1
# - niah_multikey_2
# - niah_multikey_3
# - niah_multiquery
# - niah_multivalue
- ruler_vt
- ruler_cwe
- ruler_fwe
- ruler_qa_squad
- ruler_qa_hotpot
# - ruler_cwe
# - ruler_fwe
# - ruler_qa_squad
# - ruler_qa_hotpot
aggregate_metric_list:
- metric: acc
weight_by_size: False
......
......@@ -17,14 +17,12 @@
import itertools
import random
import string
from functools import cache
import datasets
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer
TASKS = {
"variable_tracking": {
......@@ -40,11 +38,6 @@ TEMPLATE = (
)
@cache
def get_tokenizer(pretrained):
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
def generate_chains(num_chains, num_hops, is_icl=False):
vars_all = []
k = 5 if not is_icl else 3
......@@ -161,17 +154,19 @@ def sys_vartrack_w_noise_random(
)
# Calculate the number of tokens in the example
total_tokens = len(TOKENIZER(input_text + f" {answer}").input_ids)
print(
f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
# )
if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
num_noises -= incremental
break
num_noises += incremental
print("Num noises:", num_noises)
# print("Num noises:", num_noises)
# Generate samples
for index in tqdm(range(num_samples)):
for index in tqdm(
        range(num_samples), desc=f"Generating VT Samples | {max_seq_length}"
):
used_noises = num_noises
while True:
try:
......