Commit bf64b46f authored by Baber

fix vt

parent 83ca7278
@@ -24,17 +24,17 @@ from tqdm import tqdm

 from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer


 TASKS = {
     "variable_tracking": {
         "tokens_to_generate": 30,
         "template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
-        "answer_prefix": """Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are:""",
+        "answer_prefix": """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: """,
     },
 }

 TEMPLATE = (
-    """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above."""
-    + TASKS["variable_tracking"]["answer_prefix"]
+    TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"]
 )
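For reference, TEMPLATE now reuses the strings stored in TASKS instead of repeating the template literal, so the prompt and answer prefix are defined in one place. A minimal, self-contained sketch of how the combined string renders; the TASKS entries are copied from this diff (including the upstream "assgined" spelling), and the sample values are made up:

    TASKS = {
        "variable_tracking": {
            "template": "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.",
            "answer_prefix": " Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: ",
        },
    }
    TEMPLATE = (
        TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"]
    )

    # {context}, {query} and {num_v} are filled in by the data generator.
    print(TEMPLATE.format(context="VAR X1 = 12345 ...", query="12345", num_v=1))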
@@ -63,9 +63,7 @@ def generate_chains(
     return vars_ret, chains_ret


-def generate_input_output(
-    num_noises: int, num_chains: int, num_hops: int, is_icl: bool = False
-) -> tuple[str, list[str]]:
+def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
     vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)

     noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
@@ -81,8 +79,7 @@ def generate_input_output(
         assert len(sentences) > len(chains[0]), (
             "Noises too short, unable to generate data"
         )
-    # ruff: noqa
-    except:
+    except:  # noqa: E722
         print("reduces chain length for not enough noises")
         chains = [chain[: len(sentences) - 1] for chain in chains]

     # sample random positions to insert variable assignment
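The try/except above guards against configurations where the noise pool has fewer sentences than a chain has assignments; instead of failing, every chain is trimmed to fit. A toy illustration with made-up values:

    sentences = ["noise one", "noise two", "noise three"]  # only 3 noise sentences
    chains = [["VAR A1 = 12345", "VAR B7 = A1", "VAR C3 = B7", "VAR D9 = C3"]]
    if len(sentences) <= len(chains[0]):
        print("reduces chain length for not enough noises")
        chains = [chain[: len(sentences) - 1] for chain in chains]
    print(chains)  # each chain trimmed to 2 assignments so it fits between noises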
@@ -136,41 +133,37 @@ def sys_vartrack_w_noise_random(
     tokens_to_generate=30,
     icl_example: dict = None,
     remove_newline_tab=False,
-) -> list[dict]:
+):
     write_jsons = []
+    tokens_to_generate = tokens_to_generate

     # Find the perfect num_noises
     num_noises = incremental
-    TOKENIZER = tokenizer

     total_tokens = 0  # Track the total tokens generated for this example
     example_tokens = 0
     if add_fewshot and (icl_example is not None):
         icl_example_out = " ".join(icl_example["outputs"])
-        prefix = icl_example["gen_prefix"]
-        icl_example = (
-            icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
-        )
-        example_tokens = len(TOKENIZER(icl_example).input_ids)
+        icl_example = icl_example["input"] + " " + icl_example_out + "\n\n"
+        example_tokens = len(tokenizer(icl_example).input_ids)
     while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
         input_text, answer = generate_input_output(
             num_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)
         )
         # Calculate the number of tokens in the example
-        total_tokens = len(TOKENIZER(input_text + f" {answer}").input_ids)
-        # print(
-        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
-        # )
+        total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
+        print(
+            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
+        )
         if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
             num_noises -= incremental
             break
         num_noises += incremental
-    # print("Num noises:", num_noises)
+    print("Num noises:", num_noises)

     # Generate samples
-    for index in tqdm(
-        range(num_samples), desc=f"Generating VT Samples| {max_seq_length}"
-    ):
+    for index in tqdm(range(num_samples)):
         used_noises = num_noises
         while True:
             try:
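The loop in the hunk above calibrates num_noises: it grows the noise count in steps of incremental until the prompt plus the generation budget would exceed max_seq_length, then steps back once. A standalone sketch of the same search pattern; count_tokens and build_example are hypothetical stand-ins for the tokenizer call and generate_input_output:

    def calibrate_num_noises(max_seq_length, tokens_to_generate=30, incremental=10):
        def count_tokens(text):  # stand-in for len(tokenizer(text).input_ids)
            return len(text.split())

        def build_example(num_noises):  # stand-in for generate_input_output(...)
            return "The grass is green. " * num_noises

        num_noises = incremental
        while True:
            total_tokens = count_tokens(build_example(num_noises))
            if total_tokens + tokens_to_generate > max_seq_length:
                num_noises -= incremental  # step back to the last size that fit
                break
            num_noises += incremental
        return num_noises

    print(calibrate_num_noises(max_seq_length=256))  # -> 50 with these stand-ins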
@@ -181,14 +174,13 @@ def sys_vartrack_w_noise_random(
                     is_icl=add_fewshot & (icl_example is None),
                 )
                 length = (
-                    len(TOKENIZER(input_text).input_ids)
+                    len(tokenizer(input_text).input_ids)
                     + tokens_to_generate
                     + example_tokens
                 )
                 assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                 break
-            # ruff: noqa
-            except:
+            except:  # noqa: E722
                 if used_noises > incremental:
                     used_noises -= incremental
@@ -206,8 +198,12 @@ def sys_vartrack_w_noise_random(
                 input_text.replace("\n", " ").replace("\t", " ").strip().split()
             )

-        gen_prefix_index = input_text.rfind(" Answer")
+        gen_prefix_index = input_text.rfind(
+            " Answer: According to the chain(s) of variable assignment"
+        )
         gen_prefix = input_text[gen_prefix_index:].strip()
-        input_text = input_text[:gen_prefix_index]
+        # This condition is to check if we are generating the few-shot.
+        if icl_example is not None:
+            input_text = input_text[:gen_prefix_index]
         formatted_output = {
             "index": index,
@@ -225,7 +221,10 @@ def sys_vartrack_w_noise_random(
 def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
     tokenizer = get_tokenizer(pretrained)
     icl_example = sys_vartrack_w_noise_random(
-        tokenizer=tokenizer, num_samples=1, max_seq_length=500, incremental=5
+        tokenizer=tokenizer,
+        num_samples=1,
+        max_seq_length=500,
+        incremental=5,
     )[0]
     write_jsons = sys_vartrack_w_noise_random(
         tokenizer=tokenizer,
...
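Based on the signature in this last hunk, calling the builder would look roughly like the sketch below; the model name is an arbitrary example, the seq value is an assumption, and only the "input" and "outputs" keys are confirmed by the lines visible in this diff:

    # Rough usage sketch, not the harness's actual entry point.
    samples = get_dataset("meta-llama/Meta-Llama-3-8B", seq=4096)
    first = samples[0]
    print(first["input"][:200])
    print(first["outputs"])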