Commit bf64b46f authored by Baber
Browse files

fix vt

parent 83ca7278
......@@ -24,17 +24,17 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
TASKS = {
"variable_tracking": {
"tokens_to_generate": 30,
"template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
"answer_prefix": """Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are:""",
"answer_prefix": """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: """,
},
}
TEMPLATE = (
"""Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above."""
+ TASKS["variable_tracking"]["answer_prefix"]
TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"]
)
......@@ -63,9 +63,7 @@ def generate_chains(
return vars_ret, chains_ret
def generate_input_output(
num_noises: int, num_chains: int, num_hops: int, is_icl: bool = False
) -> tuple[str, list[str]]:
def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)
noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
......@@ -81,8 +79,7 @@ def generate_input_output(
assert len(sentences) > len(chains[0]), (
"Noises too short, unable to generate data"
)
# ruff: noqa
except:
except: # noqa: E722
print("reduces chain length for not enough noises")
chains = [chain[: len(sentences) - 1] for chain in chains]
# sample random positions to insert variable assignment
......@@ -136,41 +133,37 @@ def sys_vartrack_w_noise_random(
tokens_to_generate=30,
icl_example: dict = None,
remove_newline_tab=False,
) -> list[dict]:
):
write_jsons = []
tokens_to_generate = tokens_to_generate
# Find the perfect num_noises
num_noises = incremental
TOKENIZER = tokenizer
total_tokens = 0 # Track the total tokens generated for this example
example_tokens = 0
if add_fewshot and (icl_example is not None):
icl_example_out = " ".join(icl_example["outputs"])
prefix = icl_example["gen_prefix"]
icl_example = (
icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
)
example_tokens = len(TOKENIZER(icl_example).input_ids)
icl_example = icl_example["input"] + " " + icl_example_out + "\n\n"
example_tokens = len(tokenizer(icl_example).input_ids)
while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
input_text, answer = generate_input_output(
num_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)
)
# Calculate the number of tokens in the example
total_tokens = len(TOKENIZER(input_text + f" {answer}").input_ids)
# print(
# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
# )
total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
print(
f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}"
)
if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
num_noises -= incremental
break
num_noises += incremental
# print("Num noises:", num_noises)
print("Num noises:", num_noises)
# Generate samples
for index in tqdm(
range(num_samples), desc=f"Generating VT Samples| {max_seq_length}"
):
for index in tqdm(range(num_samples)):
used_noises = num_noises
while True:
try:
......@@ -181,14 +174,13 @@ def sys_vartrack_w_noise_random(
is_icl=add_fewshot & (icl_example is None),
)
length = (
len(TOKENIZER(input_text).input_ids)
len(tokenizer(input_text).input_ids)
+ tokens_to_generate
+ example_tokens
)
assert length <= max_seq_length, f"{length} exceeds max_seq_length."
break
# ruff: noqa
except:
except: # noqa: E722
if used_noises > incremental:
used_noises -= incremental
......@@ -206,8 +198,12 @@ def sys_vartrack_w_noise_random(
input_text.replace("\n", " ").replace("\t", " ").strip().split()
)
gen_prefix_index = input_text.rfind(" Answer")
gen_prefix_index = input_text.rfind(
" Answer: According to the chain(s) of variable assignment"
)
gen_prefix = input_text[gen_prefix_index:].strip()
# This condition is to check if we are generating the few-shot.
if icl_example is not None:
input_text = input_text[:gen_prefix_index]
formatted_output = {
"index": index,
......@@ -225,7 +221,10 @@ def sys_vartrack_w_noise_random(
def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
tokenizer = get_tokenizer(pretrained)
icl_example = sys_vartrack_w_noise_random(
tokenizer=tokenizer, num_samples=1, max_seq_length=500, incremental=5
tokenizer=tokenizer,
num_samples=1,
max_seq_length=500,
incremental=5,
)[0]
write_jsons = sys_vartrack_w_noise_random(
tokenizer=tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment