Commit 3bbc73d4 authored by Baber

cleanup

parent 2c8a1656
import re
from functools import cache
from typing import TYPE_CHECKING, Union

from transformers import AutoTokenizer


if TYPE_CHECKING:
    import transformers


SEQ_LENGTHS = (
    # 131072,
    # 65536,
    # 32768,
    # 16384,
    # 8192,
    4096,
)


@cache
def get_tokenizer(
    tokenizer=None, pretrained=None, **kwargs
) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
    pretrained = tokenizer or pretrained
    assert pretrained, "No tokenizer or pretrained provided."
    print("using tokenizer ", pretrained)
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


def postprocess_pred(predict_str: str) -> str:
    predict_str = predict_str.strip()
    # Remove all non-printable characters
    np_pattern = re.compile(r"[\x00-\x1f]")
    predict_str = np_pattern.sub("\n", predict_str).strip()
    return predict_str


def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
    score = sum(
        [
            sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
            for pred, ref in zip(preds, refs)
        ]
    ) / len(preds)
    return score
def string_match_part(preds: list[str], refs: list[list[str]]) -> float:
    # A prediction scores 1.0 if it contains at least one of its reference strings.
    score = sum(
        [
            max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
            for pred, ref in zip(preds, refs)
        ]
    ) / len(preds)
    return score
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
    # hacky: set all other lengths to -1
    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
    input_len = doc["max_length"]
    pred = postprocess_pred(results[0])
    score = string_match_all([pred], [doc["outputs"]])
    metrics[str(input_len)] = score
    return metrics


def process_results_part(doc: dict, results: list[str]) -> dict[str, float]:
    # hacky: set all other lengths to -1
    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
    input_len = doc["max_length"]
    pred = postprocess_pred(results[0])
    score = string_match_part([pred], [doc["outputs"]])
    metrics[str(input_len)] = score
    return metrics


def aggregate_metrics(metrics: list[float]) -> float:
    res = [x for x in metrics if x != -1]
    if not res:
        # we don't have any samples with this length
        return 0.0
    return sum(res) / len(res)
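For context, a minimal usage sketch of the helpers above (not part of the commit; the toy doc values are made up, but the "outputs"/"max_length" keys match those used by process_results):

# Hypothetical example: score one generation against its reference needles.
doc = {"max_length": 4096, "outputs": ["7421", "9135"]}
results = ["The special magic numbers are 7421 and 9135."]

metrics = process_results(doc, results)
# metrics == {"4096": 1.0}  -> both reference strings appear in the prediction

# Aggregation drops the -1.0 placeholders used for sequence lengths with no samples.
aggregate_metrics([1.0, -1.0, 0.5])  # -> 0.75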
@@ -13,14 +13,12 @@
 # limitations under the License
 import itertools
 import random
-from functools import cache

 import datasets
 import wonderwords
 from tqdm import tqdm
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
+from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer

 RNG = random.Random(42)
@@ -107,9 +105,9 @@ def sys_word_pair_random(
                 + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
             ).input_ids
         )
-        print(
-            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
-        )
+        # print(
+        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
+        # )
         if total_tokens + tokens_to_generate > max_seq_length:
             num_words -= incremental
             break
@@ -119,10 +117,12 @@ def sys_word_pair_random(
             num_words = len(WORDS)
             break

-    print("num_words:", num_words)
+    # print("num_words:", num_words)

     # Generate samples
-    for index in tqdm(range(num_samples)):
+    for index in tqdm(
+        range(num_samples), desc=f"Generating CWE Samples | {max_seq_length}"
+    ):
         used_words = num_words
         while True:
             try:
@@ -161,11 +161,6 @@ def sys_word_pair_random(
     return write_jsons


-@cache
-def get_tokenizer(pretrained):
-    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


 def get_dataset(pretrained, seq=None, **kwargs):
     tokenizer = get_tokenizer(pretrained)
     write_jsons = sys_word_pair_random(
...
@@ -23,12 +23,14 @@ from bs4 import BeautifulSoup
 from tqdm.asyncio import tqdm as async_tqdm


+@cache
 async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
     response = await client.get(url)
     response.raise_for_status()
     return response.text


+@cache
 async def process_html_essay(
     client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
 ) -> None:
@@ -50,6 +52,7 @@ async def process_html_essay(
         print(f"Failed to download {filename}: {str(e)}")


+@cache
 async def process_text_essay(
     client: httpx.AsyncClient, url: str, temp_folder: str
 ) -> None:
@@ -64,6 +67,7 @@ async def process_text_essay(
         print(f"Failed to download {filename}: {str(e)}")


+@cache
 async def get_essays() -> Dict[str, str]:
     temp_folder_repo = "essay_repo"
     temp_folder_html = "essay_html"
...
@@ -14,16 +14,14 @@
 import itertools
 import random
 import string
-from functools import cache

 import datasets
 import numpy as np
 import transformers
 from scipy.special import zeta
 from tqdm import tqdm
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
+from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer


 config = (
@@ -112,9 +110,11 @@ def sys_kwext(
         incremental=input_max_len // 32,
         alpha=alpha,
     )
-    print("num_example_words:", num_example_words)
+    # print("num_example_words:", num_example_words)

     # Generate samples
-    for index in tqdm(range(num_samples)):
+    for index in tqdm(
+        range(num_samples), desc=f"Generating FWE Samples | {max_seq_length}"
+    ):
         # construct input
         input_max_len = max_seq_length
         input_text, answer, _ = generate_input_output(
@@ -139,7 +139,7 @@ def sys_kwext(
             "input": input_text,
             "outputs": answer,
             "length": length,
-            "max_seq_length": max_seq_length,
+            "max_length": max_seq_length,
             "gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
         }
         write_jsons.append(formatted_output)
@@ -147,11 +147,6 @@ def sys_kwext(
     return write_jsons


-@cache
-def get_tokenizer(pretrained):
-    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


 def get_dataset(pretrained, max_seq_length=None, **kwargs):
     tokenizer = get_tokenizer(pretrained)
     write_jsons = sys_kwext(
...
 task: niah_multikey_1
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_multikey_1
+download_dataset: !function niah_utils.niah_multikey_1

 task: niah_multikey_2
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_multikey_2
+download_dataset: !function niah_utils.niah_multikey_2

 task: niah_multikey_3
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_multikey_3
+download_dataset: !function niah_utils.niah_multikey_3

 task: niah_multiquery
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_multiquery
+download_dataset: !function niah_utils.niah_multiquery

 task: niah_multivalue
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_multivalue
+download_dataset: !function niah_utils.niah_multivalue
@@ -5,20 +5,20 @@ dataset_path: ""
 dataset_name: ""
 output_type: generate_until
 test_split: test
-download_dataset: !function utils.niah_single_1
+download_dataset: !function niah_utils.niah_single_1
 doc_to_text: "{{input}}"
 doc_to_target: "{{outputs[0]}}"
 gen_prefix: "{{gen_prefix}}"
-process_results: !function utils.process_results
+process_results: !function common_utils.process_results
 metric_list:
   - metric: "4096"
-    aggregation: !function utils.aggregate_metrics
+    aggregation: !function common_utils.aggregate_metrics
     higher_is_better: true
   - metric: "8192"
-    aggregation: !function utils.aggregate_metrics
+    aggregation: !function common_utils.aggregate_metrics
     higher_is_better: true
   - metric: "16384"
-    aggregation: !function utils.aggregate_metrics
+    aggregation: !function common_utils.aggregate_metrics
     higher_is_better: true
   # - metric: "32768"
   #   aggregation: !function utils.aggregate_metrics
...
 task: niah_single_2
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_single_2
+download_dataset: !function niah_utils.niah_single_2

 task: niah_single_3
 include: niah_single_1.yaml
-download_dataset: !function utils.niah_single_3
+download_dataset: !function niah_utils.niah_single_3
-# noqa
 import itertools
-import re
-from functools import cache
-from typing import Literal, Generator, Union, TYPE_CHECKING
+from typing import Literal, Union, Generator

 import datasets
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.essays import get_all_essays
-from lm_eval.tasks.ruler.prepare import generate_samples
-
-if TYPE_CHECKING:
-    import transformers
+from lm_eval.tasks.ruler.prepare import generate_samples, get_haystack
+from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer


-def get_tokenizer(
-    **kwargs,
-) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
-    kwargs = kwargs.get("metadata", {})
-    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
-    assert pretrained, "No tokenizer or pretrained provided."
-    print("using tokenizer ", pretrained)
-    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


-# TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? The special magic {type_needle_v} for {query} mentioned in the provided text are"""
 TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""

-SEQ_LENGTHS = (
-    # 131072,
-    # 65536,
-    # 32768,
-    # 16384,
-    # 8192,
-    4096,
-)
-
-NUM_SAMPLES = 500
-REMOVE_NEWLINE_TAB = ""
-STOP_WORDS = ""
-RANDOM_SEED = 42
-
-
-@cache
-def get_haystack(
-    type_haystack: Literal["essay", "repeat", "needle"],
-) -> Union[list[str], str]:
-    NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
-    if type_haystack == "essay":
-        essay = get_all_essays()["text"]
-        # essay = json.load(open(essay))["text"]
-        haystack = re.sub(r"\s+", " ", essay).split(" ")
-    elif type_haystack == "repeat":
-        haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
-    elif type_haystack == "needle":
-        haystack = NEEDLE
-    else:
-        raise NotImplementedError(f"{type_haystack} is not implemented.")
-    return haystack


 def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
     return {
         "test": datasets.Dataset.from_list(
@@ -78,7 +26,7 @@ niah_single_1 = lambda **kwargs: download_dataset(
         type_haystack="repeat",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -91,7 +39,7 @@ niah_single_2 = lambda **kwargs: download_dataset(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -104,7 +52,7 @@ niah_single_3 = lambda **kwargs: download_dataset(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -118,7 +66,7 @@ niah_multikey_1 = lambda **kwargs: download_dataset(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_k=4,
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -131,7 +79,7 @@ niah_multikey_2 = lambda **kwargs: download_dataset(
         type_haystack="needle",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -144,7 +92,7 @@ niah_multikey_3 = lambda **kwargs: download_dataset(
         type_haystack="needle",
         type_needle_k="uuids",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -158,7 +106,7 @@ niah_multivalue = lambda **kwargs: download_dataset(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_v=4,
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )
@@ -172,45 +120,7 @@ niah_multiquery = lambda **kwargs: download_dataset(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_q=4,
-        TOKENIZER=get_tokenizer(**kwargs),
+        TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
     )
     for seq in SEQ_LENGTHS
 )


-def postprocess_pred(predict_str: str) -> str:
-    predict_str = predict_str.strip()
-    # Remove all non-printable characters
-    np_pattern = re.compile(r"[\x00-\x1f]")
-    predict_str = np_pattern.sub("\n", predict_str).strip()
-    return predict_str
-
-
-def string_match_all(preds: list[str], refs: list[list[str]]) -> float:
-    score = sum(
-        [
-            sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
-            for pred, ref in zip(preds, refs)
-        ]
-    ) / len(preds)
-    return score
-
-
-def process_results(doc: dict, results: list[str]) -> dict[str, float]:
-    # hacky: set all other lengths to -1
-    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
-    input_len = doc["max_length"]
-    pred = postprocess_pred(results[0])
-    score = string_match_all([pred], [doc["outputs"]])
-    metrics[str(input_len)] = score
-    return metrics
-
-
-def aggregate_metrics(metrics: list[float]) -> float:
-    res = [x for x in metrics if x != -1]
-    if not res:
-        # we don't have any samples with this length
-        return 0.0
-    return sum(res) / len(res)
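As a rough illustration of the new calling convention (a sketch, not part of the commit; "gpt2" is only a placeholder model name), the download_dataset lambdas now expect the tokenizer to arrive via the task's metadata dict, which get_tokenizer unpacks as tokenizer=/pretrained= keyword arguments:

# Hypothetical call: generates the 4096-token split and returns {"test": datasets.Dataset}.
ds = niah_single_1(metadata={"pretrained": "gpt2"})
print(ds["test"])  # a datasets.Dataset of generated needle-in-a-haystack samples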
@@ -15,10 +15,10 @@
 import os
 import random
+import re
 import uuid
-from linecache import cache
-from functools import lru_cache
-from typing import List, Union
+from functools import lru_cache, cache
+from typing import List, Union, Literal

 import numpy as np
 import wonderwords
@@ -29,6 +29,7 @@ from importlib.metadata import version
 from tqdm import tqdm

+from lm_eval.tasks.ruler.essays import get_all_essays

 NUM_SAMPLES = 500
 REMOVE_NEWLINE_TAB = ""
@@ -317,3 +318,21 @@ def generate_samples(
         )
         write_jsons.append(formatted_output)

     return write_jsons


+@cache
+def get_haystack(
+    type_haystack: Literal["essay", "repeat", "needle"],
+) -> Union[list[str], str]:
+    NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
+    if type_haystack == "essay":
+        essay = get_all_essays()["text"]
+        # essay = json.load(open(essay))["text"]
+        haystack = re.sub(r"\s+", " ", essay).split(" ")
+    elif type_haystack == "repeat":
+        haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
+    elif type_haystack == "needle":
+        haystack = NEEDLE
+    else:
+        raise NotImplementedError(f"{type_haystack} is not implemented.")
+    return haystack
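A small usage sketch of the relocated helper (not part of the commit): because get_haystack is wrapped in functools.cache, each haystack type is built once per process and then reused.

# Hypothetical example of the cached haystack lookup.
from lm_eval.tasks.ruler.prepare import get_haystack

filler = get_haystack("repeat")          # repeated filler sentences (str)
needle = get_haystack("needle")          # needle template with {key}/{value} slots (str)
assert get_haystack("repeat") is filler  # cached: the same object is returned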
 include: niah_single_1.yaml
 task: ruler_qa_squad
 download_dataset: !function qa_utils.get_squad
+process_results: !function common_utils.process_results_part
 test_split: test
 generation_kwargs:
   do_sample: false
...
@@ -20,9 +20,8 @@ from functools import cache
 import datasets
 import requests
 from tqdm import tqdm
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
+from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer

 config = {
     "tokens_to_generate": 32,
@@ -154,9 +153,9 @@ def generate_samples(
         input_text, answer = generate_input_output(0, num_docs, qas=qas, docs=docs)
         # Calculate the number of tokens in the example
         total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
-        print(
-            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
-        )
+        # print(
+        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}"
+        # )
         if total_tokens + tokens_to_generate > max_seq_length:
             num_docs -= incremental
             break
@@ -165,10 +164,12 @@ def generate_samples(
         if num_docs > len(docs):
             num_docs = len(docs)
             break

-    print("Number of documents:", num_docs)
+    # print("Number of documents:", num_docs)

     # Generate samples
-    for index in tqdm(range(num_samples)):
+    for index in tqdm(
+        range(num_samples), desc=f"Generating QA Samples | {max_seq_length}"
+    ):
         used_docs = num_docs
         while True:
             try:
@@ -192,7 +193,7 @@ def generate_samples(
             "input": input_text,
             "outputs": answer,
             "length": length,
-            "max_seq_length": max_seq_length,
+            "max_length": max_seq_length,
             "gen_prefix": "Answer:",
         }
         write_jsons.append(formatted_output)
@@ -200,11 +201,6 @@ def generate_samples(
     return write_jsons


-@cache
-def get_tokenizer(pretrained):
-    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


 def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
     tokenizer = get_tokenizer(pretrained)
     write_jsons = generate_samples(
@@ -226,7 +222,7 @@ def get_qa_dataset(ds, **kwargs):
     else:
         qas, docs = read_hotpotqa()
     df = (
-        get_dataset(pretrained, docs=docs, qas=qas, max_seq_length=seq)
+        get_dataset(pretrained=pretrained, docs=docs, qas=qas, max_seq_length=seq)
         for seq in SEQ_LENGTHS
     )
@@ -243,7 +239,3 @@ def get_squad(**kwargs):
 def get_hotpotqa(**kwargs):
     return get_qa_dataset("hotpotqa", **kwargs)

-# get_squad = lambda **kwargs: partial(get_qa_dataset, "squad")(**kwargs)
-# get_hotpotqa = lambda **kwargs: partial(get_qa_dataset, "hotpotqa")(**kwargs)
 group: ruler
 task:
-  - niah_single_1
-  - niah_single_2
-  - niah_single_3
-  - niah_multikey_1
-  - niah_multikey_2
-  - niah_multikey_3
-  - niah_multiquery
-  - niah_multivalue
+  # - niah_single_1
+  # - niah_single_2
+  # - niah_single_3
+  # - niah_multikey_1
+  # - niah_multikey_2
+  # - niah_multikey_3
+  # - niah_multiquery
+  # - niah_multivalue
   - ruler_vt
-  - ruler_cwe
-  - ruler_fwe
-  - ruler_qa_squad
-  - ruler_qa_hotpot
+  # - ruler_cwe
+  # - ruler_fwe
+  # - ruler_qa_squad
+  # - ruler_qa_hotpot
 aggregate_metric_list:
   - metric: acc
     weight_by_size: False
...
@@ -17,14 +17,12 @@
 import itertools
 import random
 import string
-from functools import cache

 import datasets
 import numpy as np
 from tqdm import tqdm
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
+from lm_eval.tasks.ruler.common_utils import SEQ_LENGTHS, get_tokenizer

 TASKS = {
     "variable_tracking": {
@@ -40,11 +38,6 @@ TEMPLATE = (
 )


-@cache
-def get_tokenizer(pretrained):
-    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


 def generate_chains(num_chains, num_hops, is_icl=False):
     vars_all = []
     k = 5 if not is_icl else 3
@@ -161,17 +154,19 @@ def sys_vartrack_w_noise_random(
         )
         # Calculate the number of tokens in the example
         total_tokens = len(TOKENIZER(input_text + f" {answer}").input_ids)
-        print(
-            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
-        )
+        # print(
+        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
+        # )
         if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
             num_noises -= incremental
             break
         num_noises += incremental

-    print("Num noises:", num_noises)
+    # print("Num noises:", num_noises)

     # Generate samples
-    for index in tqdm(range(num_samples)):
+    for index in tqdm(
+        range(num_samples), desc=f"Generating VT Samples| {max_seq_length}"
+    ):
         used_noises = num_noises
         while True:
             try:
...