Unverified Commit b50d1632 authored by Mo Li's avatar Mo Li Committed by GitHub
Browse files

[Fix] Refactor Needlebench Configs for CLI Testing Support (#1020)

* add needlebench datasets suffix

* fix import

* update run.py args for summarizer key and dataset suffix

* update utils/run.py
parent 2d4e5597
......@@ -59,7 +59,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -82,10 +82,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -108,4 +108,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
from mmengine.config import read_base

with read_base():
    from .needlebench_multi_reasoning import needlebench_datasets_2needle_en as needlebench_multi_2needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_3needle_en as needlebench_multi_3needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_4needle_en as needlebench_multi_4needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_5needle_en as needlebench_multi_5needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_2needle_zh as needlebench_multi_2needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_3needle_zh as needlebench_multi_3needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_4needle_zh as needlebench_multi_4needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_5needle_zh as needlebench_multi_5needle_zh_datasets
    from .needlebench_single import needlebench_datasets_en as needlebench_origin_en_datasets
    from .needlebench_single import needlebench_datasets_zh as needlebench_origin_zh_datasets
    from .needlebench_multi_retrieval import needlebench_datasets_en as needlebench_parallel_en_datasets
    from .needlebench_multi_retrieval import needlebench_datasets_zh as needlebench_parallel_zh_datasets

# Concatenate every imported list whose name ends with '_datasets' into the
# single aggregate list that downstream configs consume.
_dataset_lists = [v for k, v in locals().items() if k.endswith('_datasets')]
needlebench_datasets = sum(_dataset_lists, [])
from mmengine.config import read_base

with read_base():
    from .needlebench_multi_reasoning_4k import needlebench_2needle_en_datasets as needlebench_multi_2needle_en_datasets
    from .needlebench_multi_reasoning_4k import needlebench_3needle_en_datasets as needlebench_multi_3needle_en_datasets
    from .needlebench_multi_reasoning_4k import needlebench_4needle_en_datasets as needlebench_multi_4needle_en_datasets
    from .needlebench_multi_reasoning_4k import needlebench_5needle_en_datasets as needlebench_multi_5needle_en_datasets
    from .needlebench_multi_reasoning_4k import needlebench_2needle_zh_datasets as needlebench_multi_2needle_zh_datasets
    from .needlebench_multi_reasoning_4k import needlebench_3needle_zh_datasets as needlebench_multi_3needle_zh_datasets
    from .needlebench_multi_reasoning_4k import needlebench_4needle_zh_datasets as needlebench_multi_4needle_zh_datasets
    from .needlebench_multi_reasoning_4k import needlebench_5needle_zh_datasets as needlebench_multi_5needle_zh_datasets
    from .needlebench_single_4k import needlebench_en_datasets as needlebench_origin_en_datasets
    from .needlebench_single_4k import needlebench_zh_datasets as needlebench_origin_zh_datasets
    from .needlebench_multi_retrieval_4k import needlebench_en_datasets as needlebench_parallel_en_datasets
    from .needlebench_multi_retrieval_4k import needlebench_zh_datasets as needlebench_parallel_zh_datasets

# Gather every imported '*_datasets' list (snapshot locals() first, since the
# loop itself introduces new local names) into one aggregate list.
_collected = []
for _name, _value in list(locals().items()):
    if _name.endswith('_datasets'):
        _collected.extend(_value)
needlebench_datasets = _collected
......@@ -63,7 +63,7 @@ file_list = ['PaulGrahamEssays.jsonl']
needle_file_name = 'multi_needle_reasoning_en.json'
diff = 10
num_needles = 2
needlebench_datasets_2needle_en = []
needlebench_2needle_en_datasets = []
language = 'English'
for original_context_length in context_lengths:
......@@ -90,10 +90,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_2needle_en.append(dataset_dict)
needlebench_2needle_en_datasets.append(dataset_dict)
num_needles = 3
needlebench_datasets_3needle_en = []
needlebench_3needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -119,10 +119,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_3needle_en.append(dataset_dict)
needlebench_3needle_en_datasets.append(dataset_dict)
num_needles = 4
needlebench_datasets_4needle_en = []
needlebench_4needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -148,10 +148,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_4needle_en.append(dataset_dict)
needlebench_4needle_en_datasets.append(dataset_dict)
num_needles = 5
needlebench_datasets_5needle_en = []
needlebench_5needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -177,7 +177,7 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_5needle_en.append(dataset_dict)
needlebench_5needle_en_datasets.append(dataset_dict)
# ----------Chinese Version----------
base_path = './data/needlebench'
......@@ -186,7 +186,7 @@ file_list = ['zh_finance.jsonl']
needle_file_name = 'multi_needle_reasoning_zh.json'
diff = 10
num_needles = 2
needlebench_datasets_2needle_zh = []
needlebench_2needle_zh_datasets = []
language = 'Chinese'
for original_context_length in context_lengths:
......@@ -213,10 +213,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_2needle_zh.append(dataset_dict)
needlebench_2needle_zh_datasets.append(dataset_dict)
num_needles = 3
needlebench_datasets_3needle_zh = []
needlebench_3needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -242,10 +242,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_3needle_zh.append(dataset_dict)
needlebench_3needle_zh_datasets.append(dataset_dict)
num_needles = 4
needlebench_datasets_4needle_zh = []
needlebench_4needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -271,10 +271,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_4needle_zh.append(dataset_dict)
needlebench_4needle_zh_datasets.append(dataset_dict)
num_needles = 5
needlebench_datasets_5needle_zh = []
needlebench_5needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -300,4 +300,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_5needle_zh.append(dataset_dict)
needlebench_5needle_zh_datasets.append(dataset_dict)
......@@ -58,7 +58,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
depths_float = generate_depth_percents(
document_depth_percent_intervals,
......@@ -84,10 +84,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
for original_context_length in context_lengths:
dataset_dict = {
......@@ -108,4 +108,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
......@@ -58,7 +58,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -83,10 +83,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -111,4 +111,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
from mmengine.config import read_base

with read_base():
    from .needlebench_multi_reasoning import needlebench_datasets_2needle_en as needlebench_multi_2needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_3needle_en as needlebench_multi_3needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_4needle_en as needlebench_multi_4needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_5needle_en as needlebench_multi_5needle_en_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_2needle_zh as needlebench_multi_2needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_3needle_zh as needlebench_multi_3needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_4needle_zh as needlebench_multi_4needle_zh_datasets
    from .needlebench_multi_reasoning import needlebench_datasets_5needle_zh as needlebench_multi_5needle_zh_datasets
    from .needlebench_single import needlebench_datasets_en as needlebench_origin_en_datasets
    from .needlebench_single import needlebench_datasets_zh as needlebench_origin_zh_datasets
    from .needlebench_multi_retrieval import needlebench_datasets_en as needlebench_parallel_en_datasets
    from .needlebench_multi_retrieval import needlebench_datasets_zh as needlebench_parallel_zh_datasets

# Flatten every imported '*_datasets' list into the aggregate list used by
# downstream evaluation configs.
_all_lists = (v for k, v in locals().items() if k.endswith('_datasets'))
needlebench_datasets = sum(_all_lists, [])
from mmengine.config import read_base

with read_base():
    from .needlebench_multi_reasoning_8k import needlebench_2needle_en_datasets as needlebench_multi_2needle_en_datasets
    from .needlebench_multi_reasoning_8k import needlebench_3needle_en_datasets as needlebench_multi_3needle_en_datasets
    from .needlebench_multi_reasoning_8k import needlebench_4needle_en_datasets as needlebench_multi_4needle_en_datasets
    from .needlebench_multi_reasoning_8k import needlebench_5needle_en_datasets as needlebench_multi_5needle_en_datasets
    from .needlebench_multi_reasoning_8k import needlebench_2needle_zh_datasets as needlebench_multi_2needle_zh_datasets
    from .needlebench_multi_reasoning_8k import needlebench_3needle_zh_datasets as needlebench_multi_3needle_zh_datasets
    from .needlebench_multi_reasoning_8k import needlebench_4needle_zh_datasets as needlebench_multi_4needle_zh_datasets
    from .needlebench_multi_reasoning_8k import needlebench_5needle_zh_datasets as needlebench_multi_5needle_zh_datasets
    from .needlebench_single_8k import needlebench_en_datasets as needlebench_origin_en_datasets
    from .needlebench_single_8k import needlebench_zh_datasets as needlebench_origin_zh_datasets
    from .needlebench_multi_retrieval_8k import needlebench_en_datasets as needlebench_parallel_en_datasets
    from .needlebench_multi_retrieval_8k import needlebench_zh_datasets as needlebench_parallel_zh_datasets

# Collect every imported '*_datasets' list (snapshot locals() first, since the
# loop itself introduces new local names) into one aggregate list.
_merged = []
for _key, _lst in list(locals().items()):
    if _key.endswith('_datasets'):
        _merged.extend(_lst)
needlebench_datasets = _merged
......@@ -63,7 +63,7 @@ file_list = ['PaulGrahamEssays.jsonl']
needle_file_name = 'multi_needle_reasoning_en.json'
diff = 10
num_needles = 2
needlebench_datasets_2needle_en = []
needlebench_2needle_en_datasets = []
language = 'English'
for original_context_length in context_lengths:
......@@ -90,10 +90,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_2needle_en.append(dataset_dict)
needlebench_2needle_en_datasets.append(dataset_dict)
num_needles = 3
needlebench_datasets_3needle_en = []
needlebench_3needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -119,10 +119,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_3needle_en.append(dataset_dict)
needlebench_3needle_en_datasets.append(dataset_dict)
num_needles = 4
needlebench_datasets_4needle_en = []
needlebench_4needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -148,10 +148,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_4needle_en.append(dataset_dict)
needlebench_4needle_en_datasets.append(dataset_dict)
num_needles = 5
needlebench_datasets_5needle_en = []
needlebench_5needle_en_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -177,7 +177,7 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_5needle_en.append(dataset_dict)
needlebench_5needle_en_datasets.append(dataset_dict)
# ----------Chinese Version----------
base_path = './data/needlebench'
......@@ -186,7 +186,7 @@ file_list = ['zh_finance.jsonl']
needle_file_name = 'multi_needle_reasoning_zh.json'
diff = 10
num_needles = 2
needlebench_datasets_2needle_zh = []
needlebench_2needle_zh_datasets = []
language = 'Chinese'
for original_context_length in context_lengths:
......@@ -213,10 +213,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_2needle_zh.append(dataset_dict)
needlebench_2needle_zh_datasets.append(dataset_dict)
num_needles = 3
needlebench_datasets_3needle_zh = []
needlebench_3needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -242,10 +242,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_3needle_zh.append(dataset_dict)
needlebench_3needle_zh_datasets.append(dataset_dict)
num_needles = 4
needlebench_datasets_4needle_zh = []
needlebench_4needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -271,10 +271,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_4needle_zh.append(dataset_dict)
needlebench_4needle_zh_datasets.append(dataset_dict)
num_needles = 5
needlebench_datasets_5needle_zh = []
needlebench_5needle_zh_datasets = []
for original_context_length in context_lengths:
for depth_percent in generate_depth_percents(
......@@ -300,4 +300,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_5needle_zh.append(dataset_dict)
needlebench_5needle_zh_datasets.append(dataset_dict)
......@@ -58,7 +58,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
depths_float = generate_depth_percents(
document_depth_percent_intervals,
......@@ -84,10 +84,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
for original_context_length in context_lengths:
dataset_dict = {
......@@ -108,4 +108,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
......@@ -58,7 +58,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
for document_depth_percent_intervals in document_depth_percent_intervals_list:
......@@ -86,10 +86,10 @@ for document_depth_percent_intervals in document_depth_percent_intervals_list:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
needle_file_name = 'needles.jsonl'
for document_depth_percent_intervals in document_depth_percent_intervals_list:
......@@ -117,4 +117,4 @@ for document_depth_percent_intervals in document_depth_percent_intervals_list:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
......@@ -58,7 +58,7 @@ document_depth_percent_interval_type = "linear"
base_path = './data/needlebench'
file_list = ['PaulGrahamEssays.jsonl']
needlebench_datasets_en = []
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -83,10 +83,10 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_en.append(dataset_dict)
needlebench_en_datasets.append(dataset_dict)
file_list = ['zh_finance.jsonl']
needlebench_datasets_zh = []
needlebench_zh_datasets = []
needle_file_name = 'needles.jsonl'
for original_context_length in context_lengths:
......@@ -111,4 +111,4 @@ for original_context_length in context_lengths:
'infer_cfg': needlebench_infer_cfg,
'eval_cfg': needlebench_eval_cfg
}
needlebench_datasets_zh.append(dataset_dict)
needlebench_zh_datasets.append(dataset_dict)
......@@ -70,11 +70,19 @@ def get_config_from_arg(args) -> Config:
datasets = []
if args.datasets:
datasets_dir = os.path.join(args.config_dir, 'datasets')
for dataset in match_cfg_file(datasets_dir, args.datasets):
for dataset_arg in args.datasets:
if '/' in dataset_arg:
dataset_name, dataset_suffix = dataset_arg.split('/', 1)
dataset_key_suffix = dataset_suffix
else:
dataset_name = dataset_arg
dataset_key_suffix = '_datasets'
for dataset in match_cfg_file(datasets_dir, [dataset_name]):
get_logger().info(f'Loading {dataset[0]}: {dataset[1]}')
cfg = Config.fromfile(dataset[1])
for k in cfg.keys():
if k.endswith('_datasets'):
if k.endswith(dataset_key_suffix):
datasets += cfg[k]
else:
dataset = {'path': args.custom_dataset_path}
......@@ -119,12 +127,26 @@ def get_config_from_arg(args) -> Config:
run_cfg=dict(num_gpus=args.num_gpus))
models.append(model)
# parse summarizer args
summarizer = args.summarizer if args.summarizer is not None else 'example'
summarizer_arg = args.summarizer if args.summarizer is not None \
else 'example'
summarizers_dir = os.path.join(args.config_dir, 'summarizers')
s = match_cfg_file(summarizers_dir, [summarizer])[0]
# Check if summarizer_arg contains '/'
if '/' in summarizer_arg:
# If it contains '/', split the string by '/'
# and use the second part as the configuration key
summarizer_file, summarizer_key = summarizer_arg.split('/', 1)
else:
# If it does not contain '/', keep the original logic unchanged
summarizer_key = 'summarizer'
summarizer_file = summarizer_arg
s = match_cfg_file(summarizers_dir, [summarizer_file])[0]
get_logger().info(f'Loading {s[0]}: {s[1]}')
cfg = Config.fromfile(s[1])
summarizer = cfg['summarizer']
# Use summarizer_key to retrieve the summarizer definition
# from the configuration file
summarizer = cfg[summarizer_key]
return Config(dict(models=models, datasets=datasets,
summarizer=summarizer),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment