Unverified commit 930d8378, authored by Baber Abbasi and committed by GitHub

Longbench bugfix (#2895)

* add warning for default `until`

* fix stop tokens; add vcsum

* bugfix: fix doc_to_target to string

* fix lsht, trec

* add task to readme

* add debugging logs for multiple input/output
parent 82fe48ec
@@ -113,6 +113,9 @@ class TaskConfig(dict):
                 )
             if "until" not in self.generation_kwargs:
+                eval_logger.warning(
+                    f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
+                )
                 self.generation_kwargs["until"] = [self.fewshot_delimiter]
         else:
             if self.output_type == "generate_until":
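The warning covers the long-standing implicit default above. A minimal standalone sketch of that defaulting behavior, using a plain dict instead of the real TaskConfig:

generation_kwargs = {"max_gen_toks": 32}  # task config gave no `until`
fewshot_delimiter = "\n\n"

if "until" not in generation_kwargs:
    # this is the silent default the new warning now surfaces
    generation_kwargs["until"] = [fewshot_delimiter]

assert generation_kwargs["until"] == ["\n\n"]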
@@ -124,7 +127,11 @@ class TaskConfig(dict):
                         else [self.fewshot_delimiter]
                     ),
                     "do_sample": False,
+                    "temperature": 0,
                 }
+                eval_logger.warning(
+                    f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
+                )

     def __getitem__(self, item):
         return getattr(self, item)
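For a generate_until task that specifies no generation_kwargs at all, the branch above now defaults to greedy decoding and logs the chosen settings. Assuming a fewshot_delimiter of "\n\n", the constructed kwargs would look like:

defaults = {
    "until": ["\n\n"],   # stop at the fewshot delimiter
    "do_sample": False,  # greedy decoding
    "temperature": 0,    # newly added alongside do_sample=False
}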
@@ -928,11 +935,17 @@ class ConfigurableTask(Task):
                 num_choice = len(test_choice)

             if isinstance(test_text, int):
+                eval_logger.debug(
+                    "doc_to_text returned an int. Assuming multiple inputs."
+                )
                 self.multiple_input = num_choice
         else:
             test_choice = None

         if isinstance(test_target, list):
+            eval_logger.debug(
+                "doc_to_target returned a list. Assuming multiple targets."
+            )
             self.multiple_target = len(test_target)
         else:
             if (isinstance(test_target, int)) and (test_choice is not None):
......
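For context on what the two new debug messages mean, here is a hedged sketch with a made-up doc (the real bookkeeping lives in ConfigurableTask): doc_to_text returning an int marks a multiple-input task, where the int picks the gold choice and every choice becomes an input, while doc_to_target returning a list marks several acceptable references.

doc = {"question": "2 + 2 = ?", "choices": ["3", "4", "5"], "answers": ["4", "four"]}

test_text = 1                 # doc_to_text returned an int -> multiple inputs
test_target = doc["answers"]  # doc_to_target returned a list -> multiple targets
test_choice = doc["choices"]

if isinstance(test_text, int):
    multiple_input = len(test_choice)   # one request per choice
if isinstance(test_target, list):
    multiple_target = len(test_target)  # several gold targets per doc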
This diff is collapsed.
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
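The doc_to_target change is the core bugfix. A minimal Jinja2 sketch of why, assuming doc_to_target is rendered as a Jinja template over the document fields: LongBench's answers column is a list, and rendering the whole list produces its Python repr rather than a reference string, which breaks string-based metrics such as qa_f1_score.

from jinja2 import Template

doc = {"answers": ["Paris", "the city of Paris"]}

print(Template("{{answers}}").render(**doc))     # ['Paris', 'the city of Paris'] (a list repr)
print(Template("{{answers[0]}}").render(**doc))  # Paris (a usable reference string)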
 tag:
   - longbench_e
 task: longbench_2wikimqa_e
@@ -5,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -95,3 +95,4 @@ If other tasks on this dataset are already supported:
 * [x] Have you noted which, if any, published evaluation setups are matched by this variant?

 ### Changelog
+v2.0: fix doc_to_target; add vcsum
@@ -138,7 +138,7 @@ DATASETS = [
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--save_prefix_path", default="longbench")
+    parser.add_argument("--save_prefix_path", default="")
     return parser.parse_args()
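Presumably the motivation for the changed default: the prefix is concatenated directly onto the dataset name when the output path is built (see the write call further down), so the old default produced names with no separator.

ds = "2wikimqa"
print("longbench" + f"{ds}.yaml")  # longbench2wikimqa.yaml (old default, missing separator)
print("" + f"{ds}.yaml")           # 2wikimqa.yaml (new default)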
@@ -156,6 +156,7 @@ generation_kwargs:
   max_gen_toks: {{ generation_kwargs.max_gen_toks }}
   temperature: {{ generation_kwargs.temperature }}
   do_sample: {{ generation_kwargs.do_sample }}
+  until: {{ generation_kwargs.until }}
 metric_list:
   - metric: {{ metric_list[0].metric }}
     aggregation: {{ metric_list[0].aggregation }}
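One subtlety of the new template line, shown as a sketch: Jinja stringifies a Python list via repr, so ["\n"] renders as ['\n'], and YAML's single-quoted strings have no escapes, so the parsed stop token is the literal two-character sequence backslash-n, consistent with how the prompts in these configs also spell newlines as literal \n.

import yaml
from jinja2 import Template

rendered = Template("until: {{ until }}").render(until=["\n"])
print(rendered)                  # until: ['\n']
print(yaml.safe_load(rendered))  # {'until': ['\\n']} (single quotes keep \n literal)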
@@ -171,10 +172,21 @@ if __name__ == "__main__":
     template = env.from_string(template_str)

     for ds in DATASETS:
         df = ds[:-2] if ds.endswith("_e") else ds
+        # from https://github.com/THUDM/LongBench/blob/2e00731f8d0bff23dc4325161044d0ed8af94c1e/LongBench/eval.py#L52C25-L52C29
+        if df in ["trec", "triviaqa", "samsum", "lsht"] + [
+            "trec_e",
+            "triviaqa_e",
+            "samsum_e",
+            "lsht_e",
+        ]:
+            until = ["\n"]
+        else:
+            until = []
         generation_kwargs = {
             "max_gen_toks": dataset2maxlen[df],
             "temperature": 1,
             "do_sample": True,
+            "until": until,
         }
         raw_doc_to_text = (
             dataset2prompt[df]
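The stop-token choice mirrors upstream LongBench's eval.py (linked in the comment): the few-shot-formatted tasks trec, triviaqa, samsum and lsht stop at a newline, everything else generates until max length. A small sanity sketch with a hypothetical helper; note the *_e aliases in the committed list are redundant, since df has already had its "_e" suffix stripped.

def pick_until(df: str) -> list[str]:
    # same rule as the generator above
    return ["\n"] if df in {"trec", "triviaqa", "samsum", "lsht"} else []

assert pick_until("trec") == ["\n"]
assert pick_until("hotpotqa") == []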
@@ -199,10 +211,10 @@ if __name__ == "__main__":
             "test_split": "test",
             "dataset_name": ds,
             "doc_to_text": raw_doc_to_text,
-            "doc_to_target": "{{answers}}",
+            "doc_to_target": "{{answers[0]}}",
             "generation_kwargs": generation_kwargs,
             "metric_list": metric_list,
-            "metadata": {"version": "1.0"},
+            "metadata": {"version": "2.0"},
         }

         # Render template
@@ -211,35 +223,3 @@ if __name__ == "__main__":
         # Save to file
         with open(args.save_prefix_path + f"{ds}.yaml", "w") as f:
             f.write(rendered_yaml)
-
-    # for ds in DATASETS:
-    #     df = ds[:-2] if ds.endswith("_e") else ds
-    #     generation_kwargs = {"max_gen_toks": dataset2maxlen[df], "temperature": 1, "do_sample": False}
-    #     # Escape newlines and curly braces
-    #     raw_doc_to_text = dataset2prompt[df].replace("\n", "\\n").replace("{", "{{").replace("}", "}}")
-    #     metric_list = [
-    #         {"metric": f"!function metrics.{dataset2metric[df]}", "aggregation": "mean", "higher_is_better": True}]
-    #     yaml_dict = {
-    #         "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
-    #         "task": f"longbench_{ds}",
-    #         "dataset_path": "THUDM/LongBench",
-    #         "test_split": "test",
-    #         "dataset_name": ds,
-    #         "doc_to_text": raw_doc_to_text,
-    #         "doc_to_target": "{{answers}}",
-    #         "generation_kwargs": generation_kwargs,
-    #         "metric_list": metric_list,
-    #         "metadata": {"version": "1.0"}
-    #     }
-    #     template = env.from_string(yaml_dict)
-    #
-    #
-    #     file_save_path = args.save_prefix_path + f"{ds}.yaml"
-    #     with open(file_save_path, "w", encoding="utf-8") as yaml_file:
-    #         yaml.dump(
-    #             yaml_dict,
-    #             yaml_file,
-    #             allow_unicode=True,
-    #             default_flow_style=False,
-    #             sort_keys=False
-    #         )
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: dureader
 doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 128
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_zh_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report_e
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.code_sim_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc_e
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.code_sim_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lsht
 doc_to_text: '请判断给定新闻的类别,下面是一些例子。\n\n{{context}}\n{{input}}'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
+process_results: !function metrics.classification_score
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: ['\n']
 metric_list:
-  - metric: !function metrics.classification_score
+  - metric: "classification_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
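How the lsht pieces fit together, assuming the harness's usual process_results flow: the process_results hook computes the per-document score and returns it keyed by metric name, and metric_list now lists that key as a plain string so the mean aggregation can pick it up.

# hypothetical per-doc dicts as emitted by metrics.classification_score ...
per_doc = [{"classification_score": 1.0}, {"classification_score": 0.0}]
# ... reduced under the listed metric name with the configured aggregation:
scores = [d["classification_score"] for d in per_doc]
print(sum(scores) / len(scores))  # mean -> 0.5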
@@ -124,12 +124,10 @@ def code_sim_score(predictions: list[str], references: list[str], **kwargs) -> float:
     return fuzz.ratio(prediction, ground_truth) / 100


-def classification_score(
-    predictions: list[str], references: list[str], **kwargs
-) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
+    prediction, ground_truth = results[0], doc["answers"][0]
     em_match_list = []
-    all_classes = kwargs["all_classes"]
+    all_classes = doc["all_classes"]
     for class_name in all_classes:
         if class_name in prediction:
             em_match_list.append(class_name)
@@ -140,12 +138,14 @@ def classification_score(
         score = 1.0 / len(em_match_list)
     else:
         score = 0.0
-    return score
+    return {"classification_score": score}
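The signature change turns classification_score from a plain metric into a process_results hook: the harness passes the raw doc plus the model's generations, and the function returns a {metric_name: value} dict. A hedged usage sketch with made-up data, assuming the full patched function (part of its body is elided between the two hunks):

doc = {"answers": ["体育"], "all_classes": ["体育", "财经", "时政"]}
results = ["这篇新闻属于体育类"]  # the model's single generation

# "体育" (sports) is the only class mentioned in the prediction and matches the answer
print(classification_score(doc, results))  # {'classification_score': 1.0}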

 def rouge_score(predictions: list[str], references: list[str], **kwargs) -> float:
+    global rouge
+    if "rouge" not in globals():
+        rouge = Rouge()
     prediction, ground_truth = predictions[0], references[0]
-    rouge = Rouge()
     try:
         scores = rouge.get_scores([prediction], [ground_truth], avg=True)
         # ruff: noqa
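The rouge change replaces per-call construction with a lazily initialized module global, since building the scorer on every call is wasted work. The same pattern in isolation, as a sketch assuming the rouge package:

from rouge import Rouge

_rouge = None

def get_rouge():
    global _rouge
    if _rouge is None:    # construct once on first use ...
        _rouge = Rouge()
    return _rouge         # ... then reuse across every scoring call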
@@ -162,7 +162,7 @@ def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> float:
     return score


-def f1_score(predictions: list[str], references: list[str], **kwargs):
+def f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
     try:
         prediction, ground_truth = predictions[0], references[0]
     except:
......
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multi_news
 doc_to_text: 'You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{{context}}\n\nNow, write a one-page summary of all the news.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multi_news_e
 doc_to_text: 'You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{{context}}\n\nNow, write a one-page summary of all the news.\n\nSummary:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.rouge_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multifieldqa_en
 doc_to_text: 'Read the following text and answer briefly.\n\n{{context}}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multifieldqa_en_e
 doc_to_text: 'Read the following text and answer briefly.\n\n{{context}}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0
@@ -6,14 +6,15 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multifieldqa_zh
 doc_to_text: '阅读以下文字并用中文简短回答:\n\n{{context}}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{{input}}\n回答:'
-doc_to_target: '{{answers}}'
+doc_to_target: '{{answers[0]}}'
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
+  until: []
 metric_list:
   - metric: !function metrics.qa_f1_zh_score
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 1.0
+  version: 2.0