Commit bd34e852 authored by Baber

fix tokenizer

parent 7399dae1
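The change replaces an import-time tokenizer instantiation with a lazily evaluated, cached accessor that is passed to `generate_samples` explicitly. Below is a minimal, self-contained sketch of that pattern, assuming Python 3.9+ for `functools.cache` and an installed `transformers` package; `make_samples` is a hypothetical stand-in for the harness's `generate_samples`, not its real signature.

```python
import os
from functools import cache

from transformers import AutoTokenizer


@cache
def get_tokenizer():
    # Resolve the TOKENIZER env var at first call, not at import time,
    # and memoize the instance so every caller shares one tokenizer.
    return AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))


def make_samples(num_samples, TOKENIZER=None):
    # Hypothetical stand-in for generate_samples: the tokenizer is an
    # explicit argument rather than a module-level global.
    assert TOKENIZER is not None, "TOKENIZER is not defined."
    return [len(TOKENIZER.encode(f"sample {i}")) for i in range(num_samples)]


if __name__ == "__main__":
    # Assumes the TOKENIZER env var names a valid Hugging Face tokenizer.
    os.environ.setdefault("TOKENIZER", "gpt2")
    print(make_samples(3, TOKENIZER=get_tokenizer()))
```

This avoids the failure mode of the old code, where merely importing the module crashed if the `TOKENIZER` environment variable was unset, because `AutoTokenizer.from_pretrained(None)` ran unconditionally at import.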
@@ -13,10 +13,8 @@ from packaging.version import parse as parse_version
 from importlib.metadata import version
 from tqdm import tqdm
-from transformers import AutoTokenizer
-TOKENIZER = AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))
 COUNT = 0
 NUM_SAMPLES = 500
@@ -209,8 +207,10 @@ def generate_samples(
     incremental: int = 500,
     remove_newline_tab: bool = False,
     random_seed: int = 42,
+    TOKENIZER=None,
 ):
-    global COUNT
+    assert TOKENIZER is not None, "TOKENIZER is not defined."
+    print("using tokenizer ", TOKENIZER)
     num_needle_k = max(num_needle_k, num_needle_q)
     write_jsons = []
     tokens_to_generate = tokens_to_generate
...
@@ -13,7 +13,12 @@ from lm_eval.tasks.ruler.essays import get_essays, get_all_essays
 from lm_eval.tasks.ruler.prepare import generate_samples
-TOKENIZER = AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))
+@cache
+def get_tokenizer():
+    return AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))
+# TOKENIZER = AutoTokenizer.from_pretrained(os.environ.get("TOKENIZER"))
 TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
 SEQ_LENGTHS = (
@@ -64,6 +69,7 @@ niah_single_1 = lambda: flatten(
         type_haystack="repeat",
         type_needle_k="words",
         type_needle_v="numbers",
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -76,6 +82,7 @@ niah_single_2 = lambda: flatten(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="numbers",
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -88,6 +95,7 @@ niah_single_3 = lambda: flatten(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="uuids",
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -101,6 +109,7 @@ niah_multikey_1 = lambda: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_k=4,
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -113,6 +122,7 @@ niah_multikey_2 = lambda: flatten(
         type_haystack="needle",
         type_needle_k="words",
         type_needle_v="numbers",
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -125,6 +135,7 @@ niah_multikey_3 = lambda: flatten(
         type_haystack="needle",
         type_needle_k="uuids",
         type_needle_v="uuids",
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -138,6 +149,7 @@ niah_multivalue = lambda: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_v=4,
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
@@ -151,6 +163,7 @@ niah_multiquery = lambda: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_q=4,
+        TOKENIZER=get_tokenizer(),
     )
     for seq in SEQ_LENGTHS
 )
...
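One consequence of this pattern worth noting: because `get_tokenizer` takes no arguments, `@cache` makes it an effective singleton, so all the `niah_*` tasks share a single tokenizer instance and the `from_pretrained` load happens at most once. And since each task is built inside a lambda, the `TOKENIZER` environment variable only needs to be set before the first task is constructed, not before the module is imported.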