Commit 7ab8b057 authored by Baber
Browse files

add gen_prefix

parent ebccca1e
...@@ -32,8 +32,8 @@ SEQ_LENGTHS = ( ...@@ -32,8 +32,8 @@ SEQ_LENGTHS = (
# 131072, # 131072,
# 65536, # 65536,
# 32768, # 32768,
16384, # 16384,
8192, # 8192,
4096, 4096,
) )
...@@ -61,7 +61,7 @@ def get_haystack( ...@@ -61,7 +61,7 @@ def get_haystack(
return haystack return haystack
def flatten(df: Generator) -> dict[str, datasets.Dataset]: def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
return { return {
"test": datasets.Dataset.from_list( "test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
...@@ -70,7 +70,7 @@ def flatten(df: Generator) -> dict[str, datasets.Dataset]: ...@@ -70,7 +70,7 @@ def flatten(df: Generator) -> dict[str, datasets.Dataset]:
# ruff: noqa # ruff: noqa
niah_single_1 = lambda **kwargs: flatten( niah_single_1 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="repeat"), get_haystack(type_haystack="repeat"),
max_seq_length=seq, max_seq_length=seq,
...@@ -83,7 +83,7 @@ niah_single_1 = lambda **kwargs: flatten( ...@@ -83,7 +83,7 @@ niah_single_1 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# ruff: noqa # ruff: noqa
niah_single_2 = lambda **kwargs: flatten( niah_single_2 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -96,7 +96,7 @@ niah_single_2 = lambda **kwargs: flatten( ...@@ -96,7 +96,7 @@ niah_single_2 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_single_3 = lambda **kwargs: flatten( niah_single_3 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -109,7 +109,7 @@ niah_single_3 = lambda **kwargs: flatten( ...@@ -109,7 +109,7 @@ niah_single_3 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_1 = lambda **kwargs: flatten( niah_multikey_1 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -123,7 +123,7 @@ niah_multikey_1 = lambda **kwargs: flatten( ...@@ -123,7 +123,7 @@ niah_multikey_1 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_2 = lambda **kwargs: flatten( niah_multikey_2 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="needle"), get_haystack(type_haystack="needle"),
max_seq_length=seq, max_seq_length=seq,
...@@ -136,7 +136,7 @@ niah_multikey_2 = lambda **kwargs: flatten( ...@@ -136,7 +136,7 @@ niah_multikey_2 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multikey_3 = lambda **kwargs: flatten( niah_multikey_3 = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="needle"), get_haystack(type_haystack="needle"),
max_seq_length=seq, max_seq_length=seq,
...@@ -149,7 +149,7 @@ niah_multikey_3 = lambda **kwargs: flatten( ...@@ -149,7 +149,7 @@ niah_multikey_3 = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multivalue = lambda **kwargs: flatten( niah_multivalue = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
...@@ -163,7 +163,7 @@ niah_multivalue = lambda **kwargs: flatten( ...@@ -163,7 +163,7 @@ niah_multivalue = lambda **kwargs: flatten(
for seq in SEQ_LENGTHS for seq in SEQ_LENGTHS
) )
# noqa # noqa
niah_multiquery = lambda **kwargs: flatten( niah_multiquery = lambda **kwargs: download_dataset(
generate_samples( generate_samples(
get_haystack(type_haystack="essay"), get_haystack(type_haystack="essay"),
max_seq_length=seq, max_seq_length=seq,
......
...@@ -149,7 +149,10 @@ def sys_vartrack_w_noise_random( ...@@ -149,7 +149,10 @@ def sys_vartrack_w_noise_random(
example_tokens = 0 example_tokens = 0
if add_fewshot and (icl_example is not None): if add_fewshot and (icl_example is not None):
icl_example_out = " ".join(icl_example["outputs"]) icl_example_out = " ".join(icl_example["outputs"])
icl_example = icl_example["input"] + " " + icl_example_out + "\n\n" prefix = icl_example["gen_prefix"]
icl_example = (
icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
)
example_tokens = len(TOKENIZER(icl_example).input_ids) example_tokens = len(TOKENIZER(icl_example).input_ids)
while total_tokens + tokens_to_generate + example_tokens < max_seq_length: while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
...@@ -204,12 +207,16 @@ def sys_vartrack_w_noise_random( ...@@ -204,12 +207,16 @@ def sys_vartrack_w_noise_random(
input_text.replace("\n", " ").replace("\t", " ").strip().split() input_text.replace("\n", " ").replace("\t", " ").strip().split()
) )
gen_prefix_index = input_text.rfind(" Answer")
gen_prefix = input_text[gen_prefix_index:].strip()
input_text = input_text[:gen_prefix_index]
formatted_output = { formatted_output = {
"index": index, "index": index,
"input": input_text, "input": input_text,
"outputs": answer, "outputs": answer,
"length": length, "length": length,
"max_length": max_seq_length, "max_length": max_seq_length,
"gen_prefix": gen_prefix,
} }
write_jsons.append(formatted_output) write_jsons.append(formatted_output)
...@@ -230,7 +237,9 @@ def get_dataset(pretrained, seq=None, **kwargs): ...@@ -230,7 +237,9 @@ def get_dataset(pretrained, seq=None, **kwargs):
return write_jsons return write_jsons
def get_vt_dataset(pretrained=None): def get_vt_dataset(**kwargs):
kwargs = kwargs.get("metadata", {})
pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS) df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
return { return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment