Commit bca6918f authored by Baber

cleanup

parent 3bbc73d4
@@ -201,7 +201,7 @@ def generate_samples(
     return write_jsons
 
 
-def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
+def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs) -> list[dict]:
     tokenizer = get_tokenizer(pretrained)
     write_jsons = generate_samples(
         tokenizer=tokenizer,
@@ -214,7 +214,7 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
     return write_jsons
 
 
-def get_qa_dataset(ds, **kwargs):
+def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     if ds == "squad":
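The `metadata` indirection above means callers pass tokenizer/model information one level down. A minimal sketch of that lookup pattern, with a hypothetical model id; only the keys mirror the diff:

```python
def resolve_pretrained(**kwargs) -> str:
    # Same two-step lookup as get_qa_dataset: unwrap "metadata",
    # then prefer an explicit "tokenizer" over "pretrained".
    meta = kwargs.get("metadata", {})
    return meta.get("tokenizer", meta.get("pretrained", {}))

# hypothetical caller-side kwargs; any HF model id would do
print(resolve_pretrained(metadata={"pretrained": "meta-llama/Meta-Llama-3-8B"}))
```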
 group: ruler
 task:
-  # - niah_single_1
-  # - niah_single_2
-  # - niah_single_3
-  # - niah_multikey_1
-  # - niah_multikey_2
-  # - niah_multikey_3
-  # - niah_multiquery
-  # - niah_multivalue
+  - niah_single_1
+  - niah_single_2
+  - niah_single_3
+  - niah_multikey_1
+  - niah_multikey_2
+  - niah_multikey_3
+  - niah_multiquery
+  - niah_multivalue
   - ruler_vt
-  # - ruler_cwe
-  # - ruler_fwe
-  # - ruler_qa_squad
-  # - ruler_qa_hotpot
+  - ruler_cwe
+  - ruler_fwe
+  - ruler_qa_squad
+  - ruler_qa_hotpot
 aggregate_metric_list:
   - metric: acc
     weight_by_size: False
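With `weight_by_size: False`, the group score is a plain unweighted mean of the subtask accuracies. A toy illustration with made-up scores:

```python
# Hypothetical subtask accuracies; the mean ignores subtask sizes,
# matching weight_by_size: False in the group config.
scores = {"niah_single_1": 0.92, "ruler_vt": 0.80, "ruler_cwe": 0.66}
group_acc = sum(scores.values()) / len(scores)
print(round(group_acc, 4))  # 0.7933
```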
@@ -38,7 +38,9 @@ TEMPLATE = (
 )
 
 
-def generate_chains(num_chains, num_hops, is_icl=False):
+def generate_chains(
+    num_chains: int, num_hops: int, is_icl: bool = False
+) -> tuple[list[list[str]], list[list[str]]]:
     vars_all = []
     k = 5 if not is_icl else 3
     num_hops = num_hops if not is_icl else min(10, num_hops)
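For context, `generate_chains` builds the variable-assignment chains that the variable-tracking task asks the model to follow. The repo's implementation isn't shown in this hunk; below is a self-contained toy analogue of one such chain, with all formatting assumed:

```python
import random
import string

def toy_chain(num_hops: int) -> tuple[list[str], list[str]]:
    # One chain: the first variable gets a literal value, each later
    # variable copies the previous one, so the value can be traced.
    names = ["".join(random.choices(string.ascii_uppercase, k=5))
             for _ in range(num_hops + 1)]
    value = str(random.randint(10000, 99999))
    stmts = [f"VAR {names[0]} = {value}"]
    stmts += [f"VAR {names[i]} = VAR {names[i - 1]}" for i in range(1, num_hops + 1)]
    return names, stmts

names, stmts = toy_chain(num_hops=3)
print(stmts)  # e.g. ['VAR QWXTZ = 48210', 'VAR KLMNO = VAR QWXTZ', ...]
```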
@@ -61,7 +63,9 @@ def generate_chains(num_chains, num_hops, is_icl=False):
     return vars_ret, chains_ret
 
 
-def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
+def generate_input_output(
+    num_noises: int, num_chains: int, num_hops: int, is_icl: bool = False
+) -> tuple[str, list[str]]:
     vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)
     noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
@@ -112,7 +116,7 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
     return input_text, vars[0]
 
 
-def randomize_icl(icl_example):
+def randomize_icl(icl_example: str) -> str:
     icl_tgt_cut = icl_example.index(TASKS["variable_tracking"]["answer_prefix"][-10:])
     icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split()
     for item in icl_tgt:
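The `[-10:]` / `+ 10` pair in `randomize_icl` anchors on the last ten characters of the task's answer prefix and then reads the target tokens that follow. The strings below are hypothetical; only the slicing logic mirrors the diff:

```python
answer_prefix = "The answer is:"          # assumed shape of the real prefix
icl_example = "long context ... The answer is: ABCDE FGHIJ"

cut = icl_example.index(answer_prefix[-10:])      # find "answer is:"
targets = icl_example[cut + 10:].strip().split()  # tokens after the anchor
print(targets)  # ['ABCDE', 'FGHIJ']
```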
@@ -132,7 +136,7 @@ def sys_vartrack_w_noise_random(
     tokens_to_generate=30,
     icl_example: dict = None,
     remove_newline_tab=False,
-):
+) -> list[dict]:
     write_jsons = []
     # Find the perfect num_noises
     num_noises = incremental
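The `# Find the perfect num_noises` step grows the noise count until the prompt fills the token budget. A toy version of that search; the real code measures length with the model tokenizer, which is stubbed out here:

```python
def find_num_noises(max_seq_length: int, incremental: int,
                    tokens_per_noise: int = 25) -> int:
    # Grow in steps of `incremental` while the next step still fits
    # (tokens_per_noise is a made-up stand-in for real tokenization).
    num_noises = incremental
    while (num_noises + incremental) * tokens_per_noise <= max_seq_length:
        num_noises += incremental
    return num_noises

print(find_num_noises(max_seq_length=4096, incremental=5))  # 160
```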
@@ -218,7 +222,7 @@ def sys_vartrack_w_noise_random(
     return write_jsons
 
 
-def get_dataset(pretrained, seq=None, **kwargs):
+def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
     tokenizer = get_tokenizer(pretrained)
     icl_example = sys_vartrack_w_noise_random(
         tokenizer=tokenizer, num_samples=1, max_seq_length=500, incremental=5
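Note the pattern here: the in-context example is produced by a miniature run of the same generator (one sample, a 500-token budget) and then fed into the full-length run. A stripped-down sketch of that two-pass shape, with a stand-in generator and made-up sizes:

```python
def generate(num_samples: int, max_seq_length: int, icl_example=None) -> list[dict]:
    # Stand-in for sys_vartrack_w_noise_random; real generation omitted.
    mode = "with icl" if icl_example else "no icl"
    return [{"max_length": max_seq_length, "mode": mode}] * num_samples

icl = generate(num_samples=1, max_seq_length=500)[0]  # tiny bootstrap pass
full = generate(num_samples=500, max_seq_length=4096, icl_example=icl)
print(full[0])  # {'max_length': 4096, 'mode': 'with icl'}
```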
@@ -232,7 +236,7 @@ def get_dataset(pretrained, seq=None, **kwargs):
     return write_jsons
 
 
-def get_vt_dataset(**kwargs):
+def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
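The new return annotation says `get_vt_dataset` hands back one `datasets.Dataset` per sequence length. A minimal sketch of that shape, with assumed `SEQ_LENGTHS` values and a stub in place of `get_dataset`:

```python
import datasets

SEQ_LENGTHS = [4096, 8192]  # assumed; the real list lives elsewhere in the module

def stub_get_dataset(seq: int) -> list[dict]:
    # Stand-in for get_dataset(pretrained, seq=seq).
    return [{"input": f"~{seq}-token prompt", "outputs": ["XXXXX"]}]

splits: dict[str, datasets.Dataset] = {
    str(seq): datasets.Dataset.from_list(stub_get_dataset(seq)) for seq in SEQ_LENGTHS
}
print({name: len(d) for name, d in splits.items()})  # {'4096': 1, '8192': 1}
```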