Commit bb9ed305 authored by Baber's avatar Baber
Browse files

fixup

parent bf64b46f
...@@ -13,8 +13,8 @@ DEFAULT_SEQ_LENGTHS = ( ...@@ -13,8 +13,8 @@ DEFAULT_SEQ_LENGTHS = (
# 131072, # 131072,
# 65536, # 65536,
# 32768, # 32768,
16384, # 16384,
8192, # 8192,
4096, 4096,
) )
......
...@@ -21,8 +21,14 @@ from tqdm import tqdm ...@@ -21,8 +21,14 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
CONFIG = {
"tokens_to_generate": 120,
"template": """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""",
"answer_prefix": """ Answer: The top 10 words that appear most often in the list are:""",
}
RNG = random.Random(42) RNG = random.Random(42)
TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?" TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
r = wonderwords.RandomWord() r = wonderwords.RandomWord()
...@@ -144,17 +150,16 @@ def sys_word_pair_random( ...@@ -144,17 +150,16 @@ def sys_word_pair_random(
input_example.replace("\n", " ").replace("\t", " ").strip().split() input_example.replace("\n", " ").replace("\t", " ").strip().split()
) )
gen_prefix_index = input_text.rfind(" Answer") gen_prefix_index = input_text.rfind(CONFIG["answer_prefix"])
# gen_prefix = input_text[gen_prefix_index:].strip()
input_text = input_text[:gen_prefix_index] input_text = input_text[:gen_prefix_index]
formatted_output = { formatted_output = {
"index": index, "index": index,
"input": input_text, "input": input_text.strip(),
"input_example": input_example, "input_example": input_example,
"outputs": answer, "outputs": answer,
"length": length, "length": length,
"max_length": max_seq_length, "max_length": max_seq_length,
"gen_prefix": "Answer: The top 10 words that appear most often in the list are:", "gen_prefix": CONFIG["answer_prefix"].strip(),
} }
write_jsons.append(formatted_output) write_jsons.append(formatted_output)
......
...@@ -24,17 +24,15 @@ from tqdm import tqdm ...@@ -24,17 +24,15 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = ( CONFIG = {
{ "tokens_to_generate": 50,
"tokens_to_generate": 50, "template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
"template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""", "answer_prefix": """ Answer: According to the coded text above, the three most frequently appeared words are:""",
"answer_prefix": """Answer: According to the coded text above, the three most frequently appeared words are:""", }
},
)
SEED = 42 SEED = 42
TEMPLATE = "Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?\n\n" TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
def generate_input_output( def generate_input_output(
...@@ -136,11 +134,11 @@ def sys_kwext( ...@@ -136,11 +134,11 @@ def sys_kwext(
formatted_output = { formatted_output = {
"index": index, "index": index,
"input": input_text, "input": input_text[: input_text.rfind(CONFIG["answer_prefix"])].strip(),
"outputs": answer, "outputs": answer,
"length": length, "length": length,
"max_length": max_seq_length, "max_length": max_seq_length,
"gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:", "gen_prefix": CONFIG["answer_prefix"].strip(),
} }
write_jsons.append(formatted_output) write_jsons.append(formatted_output)
......
...@@ -9,6 +9,7 @@ download_dataset: !function niah_utils.niah_single_1 ...@@ -9,6 +9,7 @@ download_dataset: !function niah_utils.niah_single_1
doc_to_text: "{{input}}" doc_to_text: "{{input}}"
doc_to_target: "{{outputs[0]}}" doc_to_target: "{{outputs[0]}}"
gen_prefix: "{{gen_prefix}}" gen_prefix: "{{gen_prefix}}"
target_delimiter: "\n\n"
process_results: !function common_utils.process_results process_results: !function common_utils.process_results
metric_list: metric_list:
- metric: "4096" - metric: "4096"
......
...@@ -23,13 +23,13 @@ from tqdm import tqdm ...@@ -23,13 +23,13 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = { CONFIG = {
"tokens_to_generate": 32, "tokens_to_generate": 32,
"template": """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""", "template": """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
"answer_prefix": """Answer:""", "answer_prefix": """Answer:""",
} }
SEED = 42 SEED = 42
TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""" TEMPLATE = CONFIG["template"]
DOCUMENT_PROMPT = "Document {i}:\n{document}" DOCUMENT_PROMPT = "Document {i}:\n{document}"
......
...@@ -25,7 +25,7 @@ from tqdm import tqdm ...@@ -25,7 +25,7 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
TASKS = { CONFIG = {
"variable_tracking": { "variable_tracking": {
"tokens_to_generate": 30, "tokens_to_generate": 30,
"template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""", "template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
...@@ -34,7 +34,8 @@ TASKS = { ...@@ -34,7 +34,8 @@ TASKS = {
} }
TEMPLATE = ( TEMPLATE = (
TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"] CONFIG["variable_tracking"]["template"]
+ CONFIG["variable_tracking"]["answer_prefix"]
) )
...@@ -97,12 +98,12 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False): ...@@ -97,12 +98,12 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
if ( if (
is_icl is_icl
and template and template
!= TASKS["variable_tracking"]["template"] != CONFIG["variable_tracking"]["template"]
+ TASKS["variable_tracking"]["answer_prefix"] + CONFIG["variable_tracking"]["answer_prefix"]
): ):
# remove model template # remove model template
cutoff = template.index(TASKS["variable_tracking"]["template"][:20]) cutoff = template.index(CONFIG["variable_tracking"]["template"][:20])
cutoff_ans = template.index(TASKS["variable_tracking"]["answer_prefix"][:10]) cutoff_ans = template.index(CONFIG["variable_tracking"]["answer_prefix"][:10])
template = ( template = (
" ".join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:] " ".join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]
) )
...@@ -114,7 +115,7 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False): ...@@ -114,7 +115,7 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
def randomize_icl(icl_example: str) -> str: def randomize_icl(icl_example: str) -> str:
icl_tgt_cut = icl_example.index(TASKS["variable_tracking"]["answer_prefix"][-10:]) icl_tgt_cut = icl_example.index(CONFIG["variable_tracking"]["answer_prefix"][-10:])
icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split() icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split()
for item in icl_tgt: for item in icl_tgt:
new_item = "".join(random.choices(string.ascii_uppercase, k=len(item))).upper() new_item = "".join(random.choices(string.ascii_uppercase, k=len(item))).upper()
...@@ -186,7 +187,7 @@ def sys_vartrack_w_noise_random( ...@@ -186,7 +187,7 @@ def sys_vartrack_w_noise_random(
if add_fewshot and (icl_example is not None): if add_fewshot and (icl_example is not None):
# insert icl_example between model template and input # insert icl_example between model template and input
cutoff = input_text.index(TASKS["variable_tracking"]["template"][:20]) cutoff = input_text.index(CONFIG["variable_tracking"]["template"][:20])
input_text = ( input_text = (
input_text[:cutoff] input_text[:cutoff]
+ randomize_icl(icl_example) + randomize_icl(icl_example)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment