Commit 7ab8b057 authored by Baber

add gen_prefix

parent ebccca1e
@@ -32,8 +32,8 @@ SEQ_LENGTHS = (
     # 131072,
     # 65536,
     # 32768,
-    16384,
-    8192,
+    # 16384,
+    # 8192,
     4096,
 )
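With 16384 and 8192 now commented out as well, only 4096 remains active, so every task builder below expands its generator over a single length. A rough sketch of that pattern, with a stubbed generate_samples standing in for the real sampler:

SEQ_LENGTHS = (4096,)  # the only length still uncommented after this change

def generate_samples(max_seq_length, **kwargs):
    # stand-in for the real sampler: pretend it yields one sample per length
    return [{"max_length": max_seq_length}]

# same shape as the niah_* builders further down: one sampler call per entry
df = (generate_samples(max_seq_length=seq) for seq in SEQ_LENGTHS)
print(sum(len(chunk) for chunk in df))  # 1 -> samples are now generated only at 4096 tokens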
@@ -61,7 +61,7 @@ def get_haystack(
     return haystack
-def flatten(df: Generator) -> dict[str, datasets.Dataset]:
+def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
     return {
         "test": datasets.Dataset.from_list(
             list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
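The renamed helper simply chains the per-sequence-length sample lists into a single test split. A minimal, self-contained sketch of that behaviour (the sample dicts are made up, not real NIAH records):

import itertools

import datasets

def download_dataset(df):
    # df yields one list of sample dicts per sequence length
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }

df = ([{"input": f"sample {i} @ {n} tokens"} for i in range(2)] for n in (4096, 8192))
print(len(download_dataset(df)["test"]))  # 4 rows, flattened across both lengths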
@@ -70,7 +70,7 @@ def flatten(df: Generator) -> dict[str, datasets.Dataset]:
 # ruff: noqa
-niah_single_1 = lambda **kwargs: flatten(
+niah_single_1 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="repeat"),
         max_seq_length=seq,
@@ -83,7 +83,7 @@ niah_single_1 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # ruff: noqa
-niah_single_2 = lambda **kwargs: flatten(
+niah_single_2 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -96,7 +96,7 @@ niah_single_2 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_single_3 = lambda **kwargs: flatten(
+niah_single_3 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -109,7 +109,7 @@ niah_single_3 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_1 = lambda **kwargs: flatten(
+niah_multikey_1 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -123,7 +123,7 @@ niah_multikey_1 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_2 = lambda **kwargs: flatten(
+niah_multikey_2 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -136,7 +136,7 @@ niah_multikey_2 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_3 = lambda **kwargs: flatten(
+niah_multikey_3 = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -149,7 +149,7 @@ niah_multikey_3 = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multivalue = lambda **kwargs: flatten(
+niah_multivalue = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -163,7 +163,7 @@ niah_multivalue = lambda **kwargs: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multiquery = lambda **kwargs: flatten(
+niah_multiquery = lambda **kwargs: download_dataset(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
......
@@ -149,7 +149,10 @@ def sys_vartrack_w_noise_random(
     example_tokens = 0
     if add_fewshot and (icl_example is not None):
         icl_example_out = " ".join(icl_example["outputs"])
-        icl_example = icl_example["input"] + " " + icl_example_out + "\n\n"
+        prefix = icl_example["gen_prefix"]
+        icl_example = (
+            icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
+        )
         example_tokens = len(TOKENIZER(icl_example).input_ids)
     while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
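Effect of the few-shot change on the rendered example, using a made-up record (only the field names come from the code above; the values are placeholders):

icl_example = {
    "input": "... Question: which variables are assigned the value 12345?",
    "gen_prefix": "Answer:",
    "outputs": ["X1", "X7"],
}
icl_example_out = " ".join(icl_example["outputs"])
prefix = icl_example["gen_prefix"]
rendered = icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
# before this commit the cue was dropped ("... 12345? X1 X7");
# now it sits between question and answer ("... 12345? Answer: X1 X7")
print(rendered)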
@@ -204,12 +207,16 @@ def sys_vartrack_w_noise_random(
             input_text.replace("\n", " ").replace("\t", " ").strip().split()
         )
+        gen_prefix_index = input_text.rfind(" Answer")
+        gen_prefix = input_text[gen_prefix_index:].strip()
+        input_text = input_text[:gen_prefix_index]
         formatted_output = {
             "index": index,
             "input": input_text,
             "outputs": answer,
             "length": length,
             "max_length": max_seq_length,
+            "gen_prefix": gen_prefix,
         }
         write_jsons.append(formatted_output)
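How the new slicing behaves on a toy prompt (the prompt text is invented; only the rfind logic comes from the diff). This assumes every generated prompt contains the " Answer" marker, since rfind would otherwise return -1 and the slices would misbehave:

input_text = "VAR X1 = 12345 ... Question: which variables equal 12345? Answer: the variables are"
gen_prefix_index = input_text.rfind(" Answer")
gen_prefix = input_text[gen_prefix_index:].strip()  # "Answer: the variables are"
input_text = input_text[:gen_prefix_index]          # prompt text without the answer cue
# the cue is stored separately under "gen_prefix", presumably so it can be
# supplied at generation time instead of being baked into "input"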
@@ -230,7 +237,9 @@ def get_dataset(pretrained, seq=None, **kwargs):
     return write_jsons
-def get_vt_dataset(pretrained=None):
+def get_vt_dataset(**kwargs):
+    kwargs = kwargs.get("metadata", {})
+    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
     return {
......
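With the new signature the builder takes harness-style keyword arguments instead of a positional tokenizer name. A small sketch of the key lookup (the model name is a placeholder; the keys are the ones the new code reads):

kwargs = {"metadata": {"tokenizer": "some-org/some-model"}}  # placeholder name
metadata = kwargs.get("metadata", {})
pretrained = metadata.get("tokenizer", metadata.get("pretrained", {}))
print(pretrained)  # "some-org/some-model" -> the value handed to get_dataset(pretrained, ...)
# the old call was get_vt_dataset(pretrained="..."); the new one is
# get_vt_dataset(metadata={"tokenizer": "..."}) with "pretrained" as a fallback key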