# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# adapted from https://github.com/NVIDIA/RULER/blob/main/scripts/data/synthetic/variable_tracking.py
import itertools
import random
import string

import datasets
import numpy as np
from tqdm import tqdm

from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer


TASKS = {
    "variable_tracking": {
        "tokens_to_generate": 30,
        "template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
        "answer_prefix": """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assigned the value {query}, they are: """,
    },
}

TEMPLATE = (
    TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"]
)


def generate_chains(
    num_chains: int, num_hops: int, is_icl: bool = False
) -> tuple[list[list[str]], list[list[str]]]:
    # ICL examples use shorter variable names and fewer hops to stay compact.
    k = 5 if not is_icl else 3
    num_hops = num_hops if not is_icl else min(10, num_hops)

    # Draw random uppercase variable names, topping up until there are enough
    # unique names for every position in every chain.
    vars_all = [
        "".join(random.choices(string.ascii_uppercase, k=k))
        for _ in range((num_hops + 1) * num_chains)
    ]
    while len(set(vars_all)) < num_chains * (num_hops + 1):
        vars_all.append("".join(random.choices(string.ascii_uppercase, k=k)))

    vars_ret = []
    chains_ret = []
    for i in range(0, len(vars_all), num_hops + 1):
        this_vars = vars_all[i : i + num_hops + 1]
        vars_ret.append(this_vars)
        # The head of each chain is assigned a literal value; every later
        # variable copies the previous one, one hop at a time.
        this_chain = [f"VAR {this_vars[0]} = {np.random.randint(10000, 99999)}"]
        for j in range(num_hops):
            this_chain.append(f"VAR {this_vars[j + 1]} = VAR {this_vars[j]} ")
        chains_ret.append(this_chain)
    return vars_ret, chains_ret
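
# Illustrative sketch of generate_chains output (not executed; the names and
# the value below are made up, since both are drawn at random). For one chain
# with num_hops=2:
#
#   vars_ret   -> [["ABCDE", "FGHIJ", "KLMNO"]]
#   chains_ret -> [["VAR ABCDE = 12345",
#                   "VAR FGHIJ = VAR ABCDE ",
#                   "VAR KLMNO = VAR FGHIJ "]]
#
# Only the head of the chain carries the literal value; the query later asks
# for every variable that transitively holds that value.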
def generate_input_output(
    num_noises: int, num_chains: int, num_hops: int, is_icl: bool = False
) -> tuple[str, list[str]]:
    chain_vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)

    noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"

    # Create a list of the repeated noise blocks.
    sentences = [noise] * num_noises
    if len(sentences) <= len(chains[0]):
        # Too few noise blocks to interleave the chain: split the noise into
        # individual sentences to create more insertion slots.
        sentences = [
            s + "." if len(s.strip()) > 0 else s
            for block in sentences
            for s in block.split(".")
        ]
        if len(sentences) <= len(chains[0]):
            print("reducing chain length to fit the available noise sentences")
            chains = [chain[: len(sentences) - 1] for chain in chains]

    # Sample random (sorted) positions and insert the variable assignments.
    for chain_i in chains:
        positions = sorted(random.sample(range(len(sentences)), len(chain_i)))
        for j, insert_pi in enumerate(positions):
            sentences.insert(insert_pi + j, chain_i[j])

    context = " ".join(sentences)
    context = context.replace(". \n", ".\n")

    template = TEMPLATE
    if (
        is_icl
        and template
        != TASKS["variable_tracking"]["template"]
        + TASKS["variable_tracking"]["answer_prefix"]
    ):
        # Strip the model chat template from the ICL example, keeping only the
        # task template and the answer prefix.
        cutoff = template.index(TASKS["variable_tracking"]["template"][:20])
        cutoff_ans = template.index(TASKS["variable_tracking"]["answer_prefix"][:10])
        template = (
            " ".join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]
        )

    value = chains[0][0].split("=")[-1].strip()
    input_text = template.format(context=context, query=value, num_v=num_hops + 1)

    return input_text, chain_vars[0]
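
# Illustrative sketch of a rendered prompt (assumed values; real prompts are
# random and padded with many noise sentences, elided here as "..."). For one
# chain with a single hop, num_v = num_hops + 1 = 2:
#
#   Memorize and track the chain(s) of variable assignment hidden in the
#   following text.
#
#   ... VAR ABCDE = 12345 ... VAR FGHIJ = VAR ABCDE ...
#   Question: Find all variables that are assigned the value 12345 in the
#   text above. Answer: According to the chain(s) of variable assignment in
#   the text above, 2 variables are assigned the value 12345, they are:
#
# The expected completion is the space-separated variable list "ABCDE FGHIJ".
# Downstream, sys_vartrack_w_noise_random splits the trailing "Answer ..."
# span off into a separate "gen_prefix" field.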
def randomize_icl(icl_example: str) -> str:
    # Replace every variable name in the ICL answer with a fresh random name of
    # the same length, so the few-shot answer cannot collide with the query's.
    icl_tgt_cut = icl_example.index(TASKS["variable_tracking"]["answer_prefix"][-10:])
    icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split()
    for item in icl_tgt:
        new_item = "".join(random.choices(string.ascii_uppercase, k=len(item)))
        icl_example = icl_example.replace(item, new_item)
    return icl_example


def sys_vartrack_w_noise_random(
    tokenizer,
    num_samples: int,
    max_seq_length: int,
    incremental: int = 10,
    num_chains: int = 1,
    num_hops: int = 4,
    add_fewshot: bool = True,
    tokens_to_generate: int = 30,
    icl_example: dict | None = None,
    remove_newline_tab: bool = False,
) -> list[dict]:
    write_jsons = []

    # Reserve room for the in-context example, if one is provided.
    example_tokens = 0
    if add_fewshot and (icl_example is not None):
        icl_example_out = " ".join(icl_example["outputs"])
        prefix = icl_example["gen_prefix"]
        icl_example = (
            icl_example["input"] + " " + prefix + " " + icl_example_out + "\n\n"
        )
        example_tokens = len(tokenizer(icl_example).input_ids)

    # Find the largest num_noises whose prompt still fits in max_seq_length.
    num_noises = incremental
    total_tokens = 0  # total tokens in the current probe example
    while total_tokens + tokens_to_generate + example_tokens < max_seq_length:
        input_text, answer = generate_input_output(
            num_noises,
            num_chains,
            num_hops,
            is_icl=add_fewshot and (icl_example is None),
        )
        # Calculate the number of tokens in the example.
        total_tokens = len(tokenizer(input_text + f" {answer}").input_ids)
        if total_tokens + tokens_to_generate + example_tokens > max_seq_length:
            num_noises -= incremental
            break
        num_noises += incremental

    # Generate samples.
    for index in tqdm(
        range(num_samples), desc=f"Generating VT Samples | {max_seq_length}"
    ):
        used_noises = num_noises
        while True:
            try:
                input_text, answer = generate_input_output(
                    used_noises,
                    num_chains,
                    num_hops,
                    is_icl=add_fewshot and (icl_example is None),
                )
                length = (
                    len(tokenizer(input_text).input_ids)
                    + tokens_to_generate
                    + example_tokens
                )
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except AssertionError:
                # Back off on noise when possible; insertion positions are
                # random, so retrying at the same level can also succeed.
                if used_noises > incremental:
                    used_noises -= incremental

        if add_fewshot and (icl_example is not None):
            # Insert the ICL example between the model template and the input.
            cutoff = input_text.index(TASKS["variable_tracking"]["template"][:20])
            input_text = (
                input_text[:cutoff]
                + randomize_icl(icl_example)
                + "\n\n"
                + input_text[cutoff:]
            )
        if remove_newline_tab:
            input_text = " ".join(
                input_text.replace("\n", " ").replace("\t", " ").strip().split()
            )

        # Split the trailing answer prefix off into its own generation-prefix field.
        gen_prefix_index = input_text.rfind(" Answer")
        gen_prefix = input_text[gen_prefix_index:].strip()
        input_text = input_text[:gen_prefix_index]

        formatted_output = {
            "index": index,
            "input": input_text,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": gen_prefix,
        }
        write_jsons.append(formatted_output)

    return write_jsons


def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
    tokenizer = get_tokenizer(pretrained)
    # Build one short sample first; it serves as the in-context example for
    # every real sample at the target sequence length.
    icl_example = sys_vartrack_w_noise_random(
        tokenizer=tokenizer, num_samples=1, max_seq_length=500, incremental=5
    )[0]
    write_jsons = sys_vartrack_w_noise_random(
        tokenizer=tokenizer,
        num_samples=500,
        max_seq_length=seq,
        icl_example=icl_example,
    )
    return write_jsons


def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    df = (
        get_dataset(pretrained, seq=seq)
        for seq in kwargs.pop("max_seq_lengths", DEFAULT_SEQ_LENGTHS)
    )
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
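
# A minimal, hypothetical smoke test (assumptions: get_tokenizer accepts a
# Hugging Face model name such as "gpt2", and that tokenizer is available
# locally). Illustrative only; lm_eval drives this module through
# get_vt_dataset rather than running it as a script.
if __name__ == "__main__":
    tok = get_tokenizer("gpt2")  # assumed signature: a model-name string
    demo = sys_vartrack_w_noise_random(
        tokenizer=tok, num_samples=2, max_seq_length=1024
    )
    print(demo[0]["input"][-300:])
    print(demo[0]["gen_prefix"], "->", demo[0]["outputs"])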