Commit bca6918f authored by Baber

cleanup

parent 3bbc73d4
@@ -201,7 +201,7 @@ def generate_samples(
     return write_jsons


-def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
+def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs) -> list[dict]:
     tokenizer = get_tokenizer(pretrained)
     write_jsons = generate_samples(
         tokenizer=tokenizer,
@@ -214,7 +214,7 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs):
     return write_jsons


-def get_qa_dataset(ds, **kwargs):
+def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     if ds == "squad":
......
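For orientation, a minimal usage sketch of the two builders above. The call site and the metadata keys are assumptions inferred from the kwargs that get_qa_dataset reads, not code from this commit:

# Hypothetical call site (assumed, not part of the diff): get_qa_dataset
# pulls its tokenizer/pretrained setting out of a "metadata" dict,
# dispatches on the dataset name ("squad" here), and returns a dict of
# datasets.Dataset objects, per the new return annotation.
splits = get_qa_dataset("squad", metadata={"tokenizer": "gpt2"})
for name, dataset in splits.items():
    print(name, len(dataset))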
 group: ruler
 task:
-  # - niah_single_1
-  # - niah_single_2
-  # - niah_single_3
-  # - niah_multikey_1
-  # - niah_multikey_2
-  # - niah_multikey_3
-  # - niah_multiquery
-  # - niah_multivalue
+  - niah_single_1
+  - niah_single_2
+  - niah_single_3
+  - niah_multikey_1
+  - niah_multikey_2
+  - niah_multikey_3
+  - niah_multiquery
+  - niah_multivalue
   - ruler_vt
-  # - ruler_cwe
-  # - ruler_fwe
-  # - ruler_qa_squad
-  # - ruler_qa_hotpot
+  - ruler_cwe
+  - ruler_fwe
+  - ruler_qa_squad
+  - ruler_qa_hotpot
 aggregate_metric_list:
   - metric: acc
     weight_by_size: False
......
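One reading note on the aggregation block: with weight_by_size: False, the group score is an unweighted mean of the subtask accuracies rather than a mean weighted by each subtask's sample count. A toy illustration of the difference, with made-up numbers:

# Made-up numbers, only to show what weight_by_size toggles.
accs = [0.9, 0.5]    # per-subtask accuracy
sizes = [100, 900]   # per-subtask sample counts
unweighted = sum(accs) / len(accs)                               # 0.70 (weight_by_size: False)
weighted = sum(a * n for a, n in zip(accs, sizes)) / sum(sizes)  # 0.54 (weight_by_size: True)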
@@ -38,7 +38,9 @@ TEMPLATE = (
 )


-def generate_chains(num_chains, num_hops, is_icl=False):
+def generate_chains(
+    num_chains: int, num_hops: int, is_icl: bool = False
+) -> tuple[list[list[str]], list[list[str]]]:
     vars_all = []
     k = 5 if not is_icl else 3
     num_hops = num_hops if not is_icl else min(10, num_hops)
@@ -61,7 +63,9 @@ def generate_chains(num_chains, num_hops, is_icl=False):
     return vars_ret, chains_ret


-def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
+def generate_input_output(
+    num_noises: int, num_chains: int, num_hops: int, is_icl: bool = False
+) -> tuple[str, list[str]]:
     vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl)
     noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
@@ -112,7 +116,7 @@ def generate_input_output(num_noises, num_chains, num_hops, is_icl=False):
     return input_text, vars[0]


-def randomize_icl(icl_example):
+def randomize_icl(icl_example: str) -> str:
     icl_tgt_cut = icl_example.index(TASKS["variable_tracking"]["answer_prefix"][-10:])
     icl_tgt = icl_example[icl_tgt_cut + 10 :].strip().split()
     for item in icl_tgt:
@@ -132,7 +136,7 @@ def sys_vartrack_w_noise_random(
     tokens_to_generate=30,
     icl_example: dict = None,
     remove_newline_tab=False,
-):
+) -> list[dict]:
     write_jsons = []
     # Find the perfect num_noises
     num_noises = incremental
@@ -218,7 +222,7 @@ def sys_vartrack_w_noise_random(
     return write_jsons


-def get_dataset(pretrained, seq=None, **kwargs):
+def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
     tokenizer = get_tokenizer(pretrained)
     icl_example = sys_vartrack_w_noise_random(
         tokenizer=tokenizer, num_samples=1, max_seq_length=500, incremental=5
@@ -232,7 +236,7 @@ def get_dataset(pretrained, seq=None, **kwargs):
     return write_jsons


-def get_vt_dataset(**kwargs):
+def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
     kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
......
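The new dict[str, datasets.Dataset] annotation on get_vt_dataset matches the pattern visible in the last hunk: get_dataset yields one sample list per entry in SEQ_LENGTHS, and the generator expression collects them. A rough sketch of the implied shape; the str(seq) keys and the Dataset construction are assumptions, since the tail of the function is not in the diff:

# Sketch only: the diff shows the generator over SEQ_LENGTHS but not the
# final packaging, so the str(seq) keys and from_list call are assumptions.
import datasets

def get_vt_dataset_sketch(pretrained) -> dict[str, datasets.Dataset]:
    per_length = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
    return {
        str(seq): datasets.Dataset.from_list(samples)
        for seq, samples in zip(SEQ_LENGTHS, per_length)
    }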