Unverified Commit 80a10075 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Add loncxt tasks (#2629)

suport for longcontext (and other synthetic tasks)
* add ruler
* add longbench
* pass `metadata` to TaskConfig
parent f47ddaf8
......@@ -72,6 +72,8 @@ This mode supports a number of command-line arguments, the details of which can
* `point_of_contact` - Point of contact for the results dataset, e.g., `yourname@example.com`.
* `gated` - whether to gate the details dataset, can be `True` or `False`.
* `--metadata`: JSON string to pass to TaskConfig. Used for some tasks which require additional metadata to be passed for processing. E.g., `--metadata '{"key": "value"}'`.
## External Library Usage
We also support using the library's external API for use within model training loops or other scripts.
......
......@@ -23,6 +23,7 @@ Dataset configuration options:
- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “data instance” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
- **custom_dataset** (`Callable`, *optional) - A function that returns a `dict[str, datasets.Dataset]` (<split_name>, dataset) object. This can be used to load a dataset from a custom source or to preprocess the dataset in a way that is not supported by the `datasets` library. Will have access to `metadata` field if defined (from config and passed to TaskManager), and `model_args` from runtime (if using `evaluate`).
- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
......@@ -53,7 +54,7 @@ Scoring details:
- **doc_to_decontamination_query** (`str`, *optional*) — Query for decontamination if `should_decontaminate` is True. If `should_decontaminate` is True but `doc_to_decontamination_query` is `None`, `doc_to_decontamination_query` will follow `doc_to_text`.
Other:
- **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field that is used to denote the version of the yaml config. Other special metadata keys are: `num_fewshot`, to override the printed `n-shot` table column for a task.
- **metadata** (`dict`, *optional*) — An optional field where arbitrary metadata can be passed. Most tasks should include a `version` key in this field that is used to denote the version of the yaml config. Other special metadata keys are: `num_fewshot`, to override the printed `n-shot` table column for a task. Will also be passed to the `custom_dataset` function if defined.
## Filters
......
......@@ -10,7 +10,24 @@ from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
from lm_eval.utils import (
handle_non_serializable,
make_table,
simple_parse_args_string,
)
def try_parse_json(value: str) -> Union[str, dict, None]:
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
if "{" in value:
raise argparse.ArgumentTypeError(
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
)
return value
def _int_or_none_list_arg_type(
......@@ -79,8 +96,8 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_args",
"-a",
default="",
type=str,
help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
type=try_parse_json,
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
)
parser.add_argument(
"--num_fewshot",
......@@ -202,11 +219,11 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--gen_kwargs",
type=str,
type=try_parse_json,
default=None,
help=(
"String arguments for model generation on greedy_until tasks,"
" e.g. `temperature=0,top_k=0,top_p=0`."
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
),
)
parser.add_argument(
......@@ -262,6 +279,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--metadata",
type=json.loads,
default=None,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
return parser
......@@ -305,7 +328,19 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(include_path=args.include_path)
metadata = (
simple_parse_args_string(args.model_args)
if isinstance(args.model_args, str)
else args.model_args
if isinstance(args.model_args, dict)
else {}
) | (
args.metadata
if isinstance(args.metadata, dict)
else simple_parse_args_string(args.metadata)
)
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
eval_logger.warning(
......@@ -411,6 +446,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
metadata=metadata,
**request_caching_args,
)
......
......@@ -60,6 +60,7 @@ class TaskConfig(dict):
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
......@@ -325,10 +326,11 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
)
if self.config.get("num_fewshot", 0) > 0:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
)
return self.test_docs()
def _process_doc(self, doc: dict) -> dict:
......@@ -938,12 +940,23 @@ class ConfigurableTask(Task):
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
def download(
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
) -> None:
if isinstance(self.config.custom_dataset, Callable):
eval_logger.warning(
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
)
self.dataset = self.config.custom_dataset(
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
)
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
)
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
......
......@@ -68,7 +68,7 @@ def simple_evaluate(
system_instruction: Optional[str] = None,
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
gen_kwargs: Optional[str] = None,
gen_kwargs: Union[str, dict, None] = None,
task_manager: Optional[TaskManager] = None,
verbosity=None,
predict_only: bool = False,
......@@ -77,6 +77,7 @@ def simple_evaluate(
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -100,9 +101,9 @@ def simple_evaluate(
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
:param rewrite_requests_cache: bool, optional
Rewrites all of the request cache if set to `True`. `None` if not desired.
Rewrites all the request cache if set to `True`. `None` if not desired.
:param delete_requests_cache: bool, optional
Deletes all of the request cache if set to `True`. `None` if not desired.
Deletes all the request cache if set to `True`. `None` if not desired.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
......@@ -122,8 +123,8 @@ def simple_evaluate(
Defaults to False (no chat template applied).
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param gen_kwargs: str
String arguments for model generation
:param gen_kwargs: dict or comma-separated string
Arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param verbosity: str
Verbosity level for logging
......@@ -137,8 +138,10 @@ def simple_evaluate(
Random seed for torch. If set to None, the seed will not be set.
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:param metadata: dict
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
:return
return
Dictionary of results
"""
if verbosity is not None:
......@@ -184,12 +187,13 @@ def simple_evaluate(
)
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
if isinstance(gen_kwargs, str):
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
if not gen_kwargs:
gen_kwargs = None
if isinstance(model, str):
......@@ -243,9 +247,19 @@ def simple_evaluate(
)
if task_manager is None:
task_manager = TaskManager()
task_dict = get_task_dict(tasks, task_manager)
metadata = (
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
if isinstance(model_args, dict)
else {}
) | (metadata or {})
task_manager = TaskManager(metadata=metadata)
task_dict = get_task_dict(
tasks,
task_manager,
)
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
......@@ -264,6 +278,9 @@ def simple_evaluate(
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)
eval_logger.info(
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
)
if predict_only:
eval_logger.info(
......
......@@ -121,6 +121,7 @@
| [qasper](qasper/README.md) | Question Answering dataset based on academic papers, testing in-depth scientific knowledge. | English |
| [race](race/README.md) | Reading comprehension assessment tasks based on English exams in China. | English |
| realtoxicityprompts | Tasks to evaluate language models for generating text with potential toxicity. | |
| [ruler](ruler/README.md) | RULER is a benchmark for testing how well language models handle long pieces of text. Requires custom arg (see readme) | English |
| [sciq](sciq/README.md) | Science Question Answering tasks to assess understanding of scientific concepts. | English |
| [score](score/README.md) | Systematic consistency and robustness evaluation for LLMs on 3 datasets(MMLU-Pro, Agi Eval and MATH) | English |
| [scrolls](scrolls/README.md) | Tasks that involve long-form reading comprehension across various domains. | English |
......
......@@ -27,11 +27,12 @@ class TaskManager:
verbosity: Optional[str] = None,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
metadata: Optional[dict] = None,
) -> None:
if verbosity is not None:
utils.setup_logging(verbosity)
self.include_path = include_path
self.metadata = metadata
self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults
)
......@@ -57,15 +58,15 @@ class TaskManager:
self,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
):
"""Creates a dictionary of tasks index.
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
:param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
:return
return
Dictionary of task names as key and task metadata
"""
if include_defaults:
......@@ -170,54 +171,54 @@ class TaskManager:
result += subtask_table.dumps() + "\n\n"
return result
def match_tasks(self, task_list):
def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks)
def _name_is_registered(self, name) -> bool:
def _name_is_registered(self, name: str) -> bool:
if name in self.all_tasks:
return True
return False
def _name_is_task(self, name) -> bool:
def _name_is_task(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True
return False
def _name_is_tag(self, name) -> bool:
def _name_is_tag(self, name: str) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True
return False
def _name_is_group(self, name) -> bool:
def _name_is_group(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True
return False
def _name_is_python_task(self, name):
def _name_is_python_task(self, name: str) -> bool:
if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True
return False
def _config_is_task(self, config) -> bool:
def _config_is_task(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], str):
return True
return False
def _config_is_group(self, config) -> bool:
def _config_is_group(self, config: dict) -> bool:
if ("task" in config) and isinstance(config["task"], list):
return True
return False
def _config_is_python_task(self, config) -> bool:
def _config_is_python_task(self, config: dict) -> bool:
if "class" in config:
return True
return False
def _get_yaml_path(self, name):
def _get_yaml_path(self, name: str):
if name not in self.task_index:
raise ValueError
return self.task_index[name]["yaml_path"]
......@@ -278,11 +279,19 @@ class TaskManager:
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = task
else:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
else:
config["metadata"] = config.get("metadata", {})
task_object = ConfigurableTask(config=config)
return {task: task_object}
def _get_group_and_subtask_from_config(config):
def _get_group_and_subtask_from_config(
config: dict,
) -> tuple[ConfigurableGroup, list[str]]:
if self.metadata is not None:
config["metadata"] = config.get("metadata", {}) | self.metadata
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
......@@ -292,7 +301,9 @@ class TaskManager:
subtask_list.append(task)
return group_name, subtask_list
def _process_group_config(config, update_config=None):
def _process_group_config(
config: dict, update_config: dict = None
) -> tuple[dict, dict]:
if update_config is not None:
config = {**config, **update_config}
_update_config = {
......@@ -412,7 +423,12 @@ class TaskManager:
task_list = [task_list]
all_loaded_tasks = dict(
collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
collections.ChainMap(
*map(
lambda task: self._load_individual_task_or_group(task),
task_list,
)
)
)
return all_loaded_tasks
......@@ -547,7 +563,7 @@ def get_task_name_from_object(task_object):
)
def _check_duplicates(task_dict: dict) -> List[str]:
def _check_duplicates(task_dict: dict) -> None:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
......
......@@ -147,4 +147,4 @@ If other tasks on this dataset are already supported:
### Changelog
version 2.0: (2025-Mar-18) add [`cococteros_va`](./cocoteros_va.yaml) task.
\ No newline at end of file
version 2.0: (2025-Mar-18) add [`cococteros_va`](./cocoteros_va.yaml) task.
tag:
- longbench
task: longbench_2wikimqa
dataset_path: THUDM/LongBench
test_split: test
dataset_name: 2wikimqa
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_2wikimqa_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: 2wikimqa_e
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
# Task-name
### Paper
Title: `LongBench v2: Towards Deeper Understanding and Reasoning on Realistic Long-context Multitasks`
Abstract: `This paper introduces LongBench v2, a benchmark designed to assess the ability of LLMs to handle long-context problems requiring deep understanding and reasoning across real-world multitasks. LongBench v2 consists of 503 challenging multiple-choice questions, with contexts ranging from 8k to 2M words, across six major task categories: single-document QA, multi-document QA, long in-context learning, long-dialogue history understanding, code repository understanding, and long structured data understanding.`
Homepage: `https://github.com/THUDM/LongBench`
### Citation
```
@article{bai2024longbench2,
title={LongBench v2: Towards Deeper Understanding and Reasoning on Realistic Long-context Multitasks},
author={Yushi Bai and Shangqing Tu and Jiajie Zhang and Hao Peng and Xiaozhi Wang and Xin Lv and Shulin Cao and Jiazheng Xu and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li},
journal={arXiv preprint arXiv:2412.15204},
year={2024}
}
@inproceedings{bai2024longbench,
title = "{L}ong{B}ench: A Bilingual, Multitask Benchmark for Long Context Understanding",
author = "Bai, Yushi and Lv, Xin and Zhang, Jiajie and Lyu, Hongchang and
Tang, Jiankai and Huang, Zhidian and Du, Zhengxiao and Liu, Xiao and Zeng, Aohan and Hou, Lei and Dong, Yuxiao and Tang, Jie and Li, Juanzi",
booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = aug,
year = "2024",
address = "Bangkok, Thailand",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2024.acl-long.172",
doi = "10.18653/v1/2024.acl-long.172",
pages = "3119--3137",
}
```
### Groups, Tags, and Tasks
#### Groups
[//]: # (* `group_name`: `Short description`)
#### Tags
* `LongBench`: `Benchmark with 21 tasks (avg. 5k-15k tokens) for evaluating long-context capabilities`
* `LongBench-E`: `Modified version with uniform length distribution (0-4k, 4k-8k, 8k+) for analyzing performance across different input lengths`
#### Tasks
* `2wikimqa`: `Question answering task using multiple Wikipedia articles as reference`
* `2wikimqa_e`: `Extended version of 2wikimqa with additional complexity or data`
* `dureader`: `Chinese machine reading comprehension dataset with real-world queries`
* `gov_report`: `Summarization task for long government reports and documents`
* `gov_report_e`: `Extended version of gov_report with additional complexity or data`
* `hotpotqa`: `Multi-hop question answering requiring reasoning across multiple paragraphs`
* `hotpotqa_e`: `Extended version of hotpotqa with additional complexity or data`
* `lcc`: `Long-form content classification across various categories and domains`
* `lcc_e`: `Extended version of lcc with additional complexity or data`
* `lsht`: `Large-scale hierarchical text classification task`
* `multi_news`: `Multi-document news summarization task`
* `multi_news_e`: `Extended version of multi_news with additional complexity or data`
* `multifieldqa_en`: `English question answering across multiple knowledge domains or fields`
* `multifieldqa_en_e`: `Extended version of multifieldqa_en with additional complexity or data`
* `multifieldqa_zh`: `Chinese question answering across multiple knowledge domains or fields`
* `musique`: `Multi-step reasoning question answering with complex queries`
* `narrativeqa`: `Question answering based on book and movie narratives`
* `passage_count`: `Task requiring counting or quantifying information across passages`
* `passage_count_e`: `Extended version of passage_count with additional complexity or data`
* `passage_retrieval_en`: `English passage retrieval task for information seeking`
* `passage_retrieval_en_e`: `Extended version of passage_retrieval_en with additional complexity or data`
* `passage_retrieval_zh`: `Chinese passage retrieval task for information seeking`
* `qasper`: `Question answering on scientific papers requiring domain knowledge`
* `qasper_e`: `Extended version of qasper with additional complexity or data`
* `qmsum`: `Query-based meeting summarization task`
* `repobench-p`: `Programming task based on code repositories`
* `repobench-p_e`: `Extended version of repobench-p with additional complexity or data`
* `samsum`: `Dialogue summarization for messenger-like conversations`
* `samsum_e`: `Extended version of samsum with additional complexity or data`
* `trec`: `Question classification task for information retrieval`
* `trec_e`: `Extended version of trec with additional complexity or data`
* `triviaqa`: `Large-scale question answering dataset with trivia questions`
* `triviaqa_e`: `Extended version of triviaqa with additional complexity or data`
* `vcsum`: `Video conference summarization task`
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
# MIT License
#
# Copyright (c) 2023 THU-KEG & Zhipu AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import argparse
from jinja2 import Environment
dataset2maxlen = {
"narrativeqa": 128,
"qasper": 128,
"multifieldqa_en": 64,
"multifieldqa_zh": 64,
"hotpotqa": 32,
"2wikimqa": 32,
"musique": 32,
"dureader": 128,
"gov_report": 512,
"qmsum": 512,
"multi_news": 512,
"vcsum": 512,
"trec": 64,
"triviaqa": 32,
"samsum": 128,
"lsht": 64,
"passage_count": 32,
"passage_retrieval_en": 32,
"passage_retrieval_zh": 32,
"lcc": 64,
"repobench-p": 64,
}
dataset2prompt = {
"narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
"qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
"multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
"hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
"gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
"qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
"multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
"vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
"trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
"samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
"lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
"passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
"passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
"passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:',
"lcc": "Please complete the code given below. \n{context}Next line of code:\n",
"repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
}
dataset2metric = {
"narrativeqa": "qa_f1_score",
"qasper": "qa_f1_score",
"multifieldqa_en": "qa_f1_score",
"multifieldqa_zh": "qa_f1_zh_score",
"hotpotqa": "qa_f1_score",
"2wikimqa": "qa_f1_score",
"musique": "qa_f1_score",
"dureader": "rouge_zh_score",
"gov_report": "rouge_score",
"qmsum": "rouge_score",
"multi_news": "rouge_score",
"vcsum": "rouge_zh_score",
"trec": "classification_score",
"triviaqa": "qa_f1_score",
"samsum": "rouge_score",
"lsht": "classification_score",
"passage_retrieval_en": "retrieval_score",
"passage_count": "count_score",
"passage_retrieval_zh": "retrieval_zh_score",
"lcc": "code_sim_score",
"repobench-p": "code_sim_score",
}
DATASETS = [
"2wikimqa",
"2wikimqa_e",
"dureader",
"gov_report",
"gov_report_e",
"hotpotqa",
"hotpotqa_e",
"lcc",
"lcc_e",
"lsht",
"multi_news",
"multi_news_e",
"multifieldqa_en",
"multifieldqa_en_e",
"multifieldqa_zh",
"musique",
"narrativeqa",
"passage_count",
"passage_count_e",
"passage_retrieval_en",
"passage_retrieval_en_e",
"passage_retrieval_zh",
"qasper",
"qasper_e",
"qmsum",
"repobench-p",
"repobench-p_e",
"samsum",
"samsum_e",
"trec",
"trec_e",
"triviaqa",
"triviaqa_e",
"vcsum",
]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--save_prefix_path", default="longbench")
return parser.parse_args()
# Create template string
template_str = """
tag:
- {{ tag[0] }}
task: {{ task }}
dataset_path: {{ dataset_path }}
test_split: {{ test_split }}
dataset_name: {{ dataset_name }}
doc_to_text: '{{ doc_to_text }}'
doc_to_target: '{{ doc_to_target }}'
generation_kwargs:
max_gen_toks: {{ generation_kwargs.max_gen_toks }}
temperature: {{ generation_kwargs.temperature }}
do_sample: {{ generation_kwargs.do_sample }}
metric_list:
- metric: {{ metric_list[0].metric }}
aggregation: {{ metric_list[0].aggregation }}
higher_is_better: {{ metric_list[0].higher_is_better }}
metadata:
version: {{ metadata.version }}
"""
if __name__ == "__main__":
args = parse_args()
env = Environment()
template = env.from_string(template_str)
for ds in DATASETS:
df = ds[:-2] if ds.endswith("_e") else ds
generation_kwargs = {
"max_gen_toks": dataset2maxlen[df],
"temperature": 1,
"do_sample": True,
}
raw_doc_to_text = (
dataset2prompt[df]
.replace("\n", "\\n")
.replace("{", "{{")
.replace("}", "}}")
)
metric_list = [
{
"metric": f"!function metrics.{dataset2metric[df]}",
"aggregation": "mean",
"higher_is_better": True,
}
]
data = {
"tag": [
"longbench_e" if ds.endswith("_e") else "longbench"
], # Now properly as a list
"task": f"longbench_{ds}",
"dataset_path": "THUDM/LongBench",
"test_split": "test",
"dataset_name": ds,
"doc_to_text": raw_doc_to_text,
"doc_to_target": "{{answers}}",
"generation_kwargs": generation_kwargs,
"metric_list": metric_list,
"metadata": {"version": "1.0"},
}
# Render template
rendered_yaml = template.render(**data)
# Save to file
with open(args.save_prefix_path + f"{ds}.yaml", "w") as f:
f.write(rendered_yaml)
# for ds in DATASETS:
# df = ds[:-2] if ds.endswith("_e") else ds
# generation_kwargs = {"max_gen_toks": dataset2maxlen[df], "temperature": 1, "do_sample": False}
# # Escape newlines and curly braces
# raw_doc_to_text = dataset2prompt[df].replace("\n", "\\n").replace("{", "{{").replace("}", "}}")
# metric_list = [
# {"metric": f"!function metrics.{dataset2metric[df]}", "aggregation": "mean", "higher_is_better": True}]
# yaml_dict = {
# "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
# "task": f"longbench_{ds}",
# "dataset_path": "THUDM/LongBench",
# "test_split": "test",
# "dataset_name": ds,
# "doc_to_text": raw_doc_to_text,
# "doc_to_target": "{{answers}}",
# "generation_kwargs": generation_kwargs,
# "metric_list": metric_list,
# "metadata": {"version": "1.0"}
# }
# template = env.from_string(yaml_dict)
#
#
# file_save_path = args.save_prefix_path + f"{ds}.yaml"
# with open(file_save_path, "w", encoding="utf-8") as yaml_file:
# yaml.dump(
# yaml_dict,
# yaml_file,
# allow_unicode=True,
# default_flow_style=False,
# sort_keys=False
# )
tag:
- longbench
task: longbench_dureader
dataset_path: THUDM/LongBench
test_split: test
dataset_name: dureader
doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 128
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.rouge_zh_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench
task: longbench_gov_report
dataset_path: THUDM/LongBench
test_split: test
dataset_name: gov_report
doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 512
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.rouge_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_gov_report_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: gov_report_e
doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 512
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.rouge_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench
task: longbench_hotpotqa
dataset_path: THUDM/LongBench
test_split: test
dataset_name: hotpotqa
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_hotpotqa_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: hotpotqa_e
doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 32
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.qa_f1_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench
task: longbench_lcc
dataset_path: THUDM/LongBench
test_split: test
dataset_name: lcc
doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.code_sim_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench_e
task: longbench_lcc_e
dataset_path: THUDM/LongBench
test_split: test
dataset_name: lcc_e
doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.code_sim_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
tag:
- longbench
task: longbench_lsht
dataset_path: THUDM/LongBench
test_split: test
dataset_name: lsht
doc_to_text: '请判断给定新闻的类别,下面是一些例子。\n\n{{context}}\n{{input}}'
doc_to_target: '{{answers}}'
generation_kwargs:
max_gen_toks: 64
temperature: 1
do_sample: True
metric_list:
- metric: !function metrics.classification_score
aggregation: mean
higher_is_better: True
metadata:
version: 1.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment