Commit bb9ed305 authored by Baber
Browse files

fixup

parent bf64b46f
......@@ -13,8 +13,8 @@ DEFAULT_SEQ_LENGTHS = (
# 131072,
# 65536,
# 32768,
16384,
8192,
# 16384,
# 8192,
4096,
)
......
......@@ -21,8 +21,14 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
CONFIG = {
"tokens_to_generate": 120,
"template": """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""",
"answer_prefix": """ Answer: The top 10 words that appear most often in the list are:""",
}
RNG = random.Random(42)
TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?"
TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
r = wonderwords.RandomWord()
......@@ -144,17 +150,16 @@ def sys_word_pair_random(
input_example.replace("\n", " ").replace("\t", " ").strip().split()
)
gen_prefix_index = input_text.rfind(" Answer")
# gen_prefix = input_text[gen_prefix_index:].strip()
gen_prefix_index = input_text.rfind(CONFIG["answer_prefix"])
input_text = input_text[:gen_prefix_index]
formatted_output = {
"index": index,
"input": input_text,
"input": input_text.strip(),
"input_example": input_example,
"outputs": answer,
"length": length,
"max_length": max_seq_length,
"gen_prefix": "Answer: The top 10 words that appear most often in the list are:",
"gen_prefix": CONFIG["answer_prefix"].strip(),
}
write_jsons.append(formatted_output)
......
......@@ -24,17 +24,15 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = (
{
"tokens_to_generate": 50,
"template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
"answer_prefix": """Answer: According to the coded text above, the three most frequently appeared words are:""",
},
)
CONFIG = {
"tokens_to_generate": 50,
"template": """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""",
"answer_prefix": """ Answer: According to the coded text above, the three most frequently appeared words are:""",
}
SEED = 42
TEMPLATE = "Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?\n\n"
TEMPLATE = CONFIG["template"] + CONFIG["answer_prefix"]
def generate_input_output(
......@@ -136,11 +134,11 @@ def sys_kwext(
formatted_output = {
"index": index,
"input": input_text,
"input": input_text[: input_text.rfind(CONFIG["answer_prefix"])].strip(),
"outputs": answer,
"length": length,
"max_length": max_seq_length,
"gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
"gen_prefix": CONFIG["answer_prefix"].strip(),
}
write_jsons.append(formatted_output)
......
......@@ -9,6 +9,7 @@ download_dataset: !function niah_utils.niah_single_1
doc_to_text: "{{input}}"
doc_to_target: "{{outputs[0]}}"
gen_prefix: "{{gen_prefix}}"
target_delimiter: "\n\n"
process_results: !function common_utils.process_results
metric_list:
- metric: "4096"
......
......@@ -23,13 +23,13 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
config = {
CONFIG = {
"tokens_to_generate": 32,
"template": """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""",
"answer_prefix": """Answer:""",
}
SEED = 42
TEMPLATE = """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}"""
TEMPLATE = CONFIG["template"]
DOCUMENT_PROMPT = "Document {i}:\n{document}"
......
......@@ -25,7 +25,7 @@ from tqdm import tqdm
from lm_eval.tasks.ruler.common_utils import DEFAULT_SEQ_LENGTHS, get_tokenizer
TASKS = {
CONFIG = {
"variable_tracking": {
"tokens_to_generate": 30,
"template": """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""",
......@@ -34,7 +34,8 @@ TASKS = {
}
TEMPLATE = (
TASKS["variable_tracking"]["template"] + TASKS["variable_tracking"]["answer_prefix"]
CONFIG["variable_tracking"]["template"]
+ CONFIG["variable_tracking"]["answer_prefix"]
)
......@@ -97,12 +98,12 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
if (
is_icl
and template
!= TASKS["variable_tracking"]["template"]
+ TASKS["variable_tracking"]["answer_prefix"]
!= CONFIG["variable_tracking"]["template"]
+ CONFIG["variable_tracking"]["answer_prefix"]
):
# remove model template
cutoff = template.index(TASKS["variable_tracking"]["template"][:20])
cutoff_ans = template.index(TASKS["variable_tracking"]["answer_prefix"][:10])
cutoff = template.index(CONFIG["variable_tracking"]["template"][:20])
cutoff_ans = template.index(CONFIG["variable_tracking"]["answer_prefix"][:10])
template = (
" ".join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:]
)
......@@ -114,7 +115,7 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
def randomize_icl(icl_example: str) -> str:
icl_tgt_cut = icl_example.index(TASKS["variable_tracking"]["answer_prefix"][-10:])
icl_tgt_cut = icl_example.index(CONFIG["variable_tracking"]["answer_prefix"][-10:])
icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split()
for item in icl_tgt:
new_item = "".join(random.choices(string.ascii_uppercase, k=len(item))).upper()
......@@ -186,7 +187,7 @@ def sys_vartrack_w_noise_random(
if add_fewshot and (icl_example is not None):
# insert icl_example between model template and input
cutoff = input_text.index(TASKS["variable_tracking"]["template"][:20])
cutoff = input_text.index(CONFIG["variable_tracking"]["template"][:20])
input_text = (
input_text[:cutoff]
+ randomize_icl(icl_example)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment