Unverified Commit 147e9d61 authored by Baber Abbasi, committed by GitHub

[longbench] fix metric calculation (#2983)

* use all answers

* use middle truncation

* maybe fix classification score

* strip classification preds

* [vllm] remove stop tokens post-hoc

* strip all preds

* pacify pre-commit

* start on truncation utility

* add to readme

* add a footgun doc

* fix newline in yaml templates

* do not strip code_sim preds!

* fix pre-commit config

* fix instruction warning

* add note to longbench readme
parent 9f152e0b
@@ -29,7 +29,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.0
+    rev: v0.11.10
     hooks:
       # Run the linter.
       - id: ruff
@@ -50,7 +50,7 @@ repos:
     rev: v0.9.29
     hooks:
       - id: pymarkdown
-        exclude: ^lm_eval/tasks/
+        exclude: ^(lm_eval/tasks/.*|docs/footguns\.md)$
         args: [fix, -r]
 # - repo: https://github.com/pre-commit/mirrors-mypy
 #   rev: v1.5.1
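For reference, the widened `exclude` is a plain Python regular expression matched against each file path by pre-commit; a quick check of what it now skips (paths are illustrative):

```python
import re

# The new pattern skips docs/footguns.md in addition to everything under
# lm_eval/tasks/, so the markdown linter no longer rewrites either.
pattern = re.compile(r"^(lm_eval/tasks/.*|docs/footguns\.md)$")

assert pattern.match("lm_eval/tasks/longbench/README.md")
assert pattern.match("docs/footguns.md")
assert not pattern.match("docs/interface.md")
```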
# Common Pitfalls and Troubleshooting Guide
This document highlights common pitfalls and troubleshooting tips when using this library. We'll continue to add more tips as we discover them.
## YAML Configuration Issues
### Newline Characters in YAML (`\n`)
**Problem:** When specifying newline characters in YAML, they may be interpreted incorrectly depending on how you format them.
```yaml
# ❌ WRONG: Single quotes don't process escape sequences
generation_kwargs:
  until: ['\n'] # Gets parsed as the literal characters '\' and 'n', i.e. "\\n"
```
```yaml
# ✅ RIGHT: Use double quotes for escape sequences
generation_kwargs:
  until: ["\n"] # Gets parsed as an actual newline character
```
**Solutions:**
- Use double quotes for strings containing escape sequences
- For multiline content, use YAML's block scalars (`|` or `>`)
- When generating YAML programmatically, be careful with how template engines handle escape sequences (see the sketch below)
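A minimal PyYAML check of the two quoting styles and a block scalar (assumes `pyyaml` is installed):

```python
import yaml

# Single quotes keep backslash-n as two characters; double quotes turn it into
# a real newline; a literal block scalar sidesteps escaping entirely.
single = yaml.safe_load("until: ['\\n']")
double = yaml.safe_load('until: ["\\n"]')
block = yaml.safe_load("text: |\n  first\n  second\n")

assert single == {"until": ["\\n"]}   # literal backslash + n
assert double == {"until": ["\n"]}    # actual newline character
assert block == {"text": "first\nsecond\n"}
```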
### Quoting in YAML
**When to use different types of quotes:**
- **No quotes**: Simple values (numbers, booleans, alphanumeric strings without special characters)
```yaml
simple_value: plain text
number: 42
```
- **Single quotes (')**:
- Preserves literal values
- Use when you need special characters to be treated literally
- Escape single quotes by doubling them: `'It''s working'`
```yaml
literal_string: 'The newline character \n is not processed here'
path: 'C:\Users\name' # Backslashes preserved
```
- **Double quotes (")**:
- Processes escape sequences like `\n`, `\t`, etc.
- Use for strings that need special characters interpreted
- Escape double quotes with backslash: `"He said \"Hello\""`
```yaml
processed_string: "First line\nSecond line" # Creates actual newline
unicode: "Copyright symbol: \u00A9" # Unicode character
```
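The same kind of quick check works for the quote-escaping rules above (again assuming `pyyaml`):

```python
import yaml

# Doubling a single quote escapes it; a backslash escapes a double quote;
# backslashes inside single quotes are preserved as-is.
assert yaml.safe_load("s: 'It''s working'") == {"s": "It's working"}
assert yaml.safe_load('s: "He said \\"Hello\\""') == {"s": 'He said "Hello"'}
assert yaml.safe_load(r"p: 'C:\Users\name'") == {"p": r"C:\Users\name"}
```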
@@ -153,11 +153,15 @@ def simple_evaluate(
             "Either 'limit' or 'samples' must be None, but both are not None."
         )

-    if isinstance(model_args, str) and (
-        "instruct" in model_args and not apply_chat_template
-    ):
+    if (
+        (isinstance(model_args, str) and "inst" in model_args.lower())
+        or (
+            isinstance(model_args, dict)
+            and any("inst" in str(v).lower() for v in model_args.values())
+        )
+    ) and not apply_chat_template:
         eval_logger.warning(
-            "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
+            "Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
         )

     if delete_requests_cache:
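The new check is just a case-insensitive substring test over the model arguments; a minimal reproduction of the predicate (the function name is illustrative, not part of the codebase):

```python
def looks_like_instruct(model_args) -> bool:
    # Mirrors the condition added above: look for "inst" in a model_args
    # string, or in any value of a model_args dict.
    if isinstance(model_args, str):
        return "inst" in model_args.lower()
    if isinstance(model_args, dict):
        return any("inst" in str(v).lower() for v in model_args.values())
    return False


assert looks_like_instruct("pretrained=meta-llama/Llama-3.1-8B-Instruct")
assert looks_like_instruct({"pretrained": "Qwen2.5-7B-Instruct"})
assert not looks_like_instruct({"pretrained": "gpt2"})
```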
@@ -834,3 +834,21 @@ def resize_image(
     # Perform the resize operation with the calculated dimensions
     return image.resize((new_width, new_height), resample_filter)
+
+
+def truncate_tokens(
+    tokens: List[int],
+    max_length: int,
+    tokenizer: "PreTrainedTokenizerBase",
+    strategy: str = "left",
+):
+    if strategy == "left":
+        return tokens[-max_length:]
+    elif strategy == "right":
+        return tokens[:max_length]
+    elif strategy == "middle":
+        # Truncate the middle of the sequence
+        left_length = max_length // 2
+        right_length = max_length - left_length
+        return tokens[:left_length] + tokens[-right_length:]
+    return None
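A quick illustration of the three strategies on a toy token list; the slicing below mirrors the helper above rather than importing it:

```python
tokens = list(range(10))   # stand-in for tokenizer output
max_length = 6

# "left": keep the last max_length tokens
assert tokens[-max_length:] == [4, 5, 6, 7, 8, 9]
# "right": keep the first max_length tokens
assert tokens[:max_length] == [0, 1, 2, 3, 4, 5]
# "middle": keep the head and the tail, drop the centre (LongBench-style)
left_len = max_length // 2
right_len = max_length - left_len
assert tokens[:left_len] + tokens[-right_len:] == [0, 1, 2, 7, 8, 9]
```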
@@ -614,6 +614,10 @@ class VLLM(TemplateLM):
             # cache generations
             for output, context in zip(cont, context):
                 generated_text = output.outputs[0].text
+                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                for term in until:
+                    if len(term) > 0:
+                        generated_text = generated_text.split(term)[0]
                 res.append(generated_text)
                 self.cache_hook.add_partial(
                     "generate_until", (context, gen_kwargs), generated_text
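The post-hoc trimming is plain string splitting on each stop sequence; a quick check of the behaviour (the strings are illustrative):

```python
# Anything generated after the first occurrence of a stop sequence is dropped.
until = ["\n", "###"]
generated_text = "42\nQuestion: what about..."

for term in until:
    if len(term) > 0:
        generated_text = generated_text.split(term)[0]

assert generated_text == "42"
```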
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_qa_f1_score
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.qa_f1_score
+  - metric: "qa_f1_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: 2wikimqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_qa_f1_score
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.qa_f1_score
+  - metric: "qa_f1_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -32,6 +32,17 @@ Homepage: `https://github.com/THUDM/LongBench`
     pages = "3119--3137",
 }
 ```

+### Notes
+
+#### Tasks without Chat Template (with `add_bos_token=True`, though this is model dependent)
+
+The original implementation suggests not using a `chat_template` for these tasks (even for instruct models):
+
+- longbench_lcc
+- longbench_repobench-p
+- longbench_samsum
+- longbench_trec
+- longbench_triviaqa

 ### Groups, Tags, and Tasks
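Following the note above, a hedged sketch of running these tasks through the Python API without a chat template (model name and keyword values are illustrative, not prescribed by the task):

```python
import lm_eval

# apply_chat_template is simply left at its default (False); add_bos_token is
# forwarded through model_args and remains model dependent, per the note above.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=Qwen/Qwen2.5-7B-Instruct,add_bos_token=True",
    tasks=["longbench_trec", "longbench_triviaqa"],
)
```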
@@ -96,3 +107,4 @@ If other tasks on this dataset are already supported:
 ### Changelog

 v2.: fix doc_to_target; add vcsum
+v3: properly use all answers for metric calculation; trim whitespace from resps; fix stop sequences not parsing correctly.
@@ -142,7 +142,6 @@ def parse_args():
     return parser.parse_args()


-# Create template string
 template_str = """
 tag:
   - {{ tag[0] }}
@@ -152,11 +151,12 @@ test_split: {{ test_split }}
 dataset_name: {{ dataset_name }}
 doc_to_text: '{{ doc_to_text }}'
 doc_to_target: '{{ doc_to_target }}'
+process_results: {{ process_results }}
 generation_kwargs:
   max_gen_toks: {{ generation_kwargs.max_gen_toks }}
   temperature: {{ generation_kwargs.temperature }}
   do_sample: {{ generation_kwargs.do_sample }}
-  until: {{ generation_kwargs.until }}
+  until: {% if has_newline %}["\\n"]{% else %}[]{% endif %}
 metric_list:
   - metric: {{ metric_list[0].metric }}
     aggregation: {{ metric_list[0].aggregation }}
@@ -173,21 +173,17 @@ if __name__ == "__main__":
     for ds in DATASETS:
         df = ds[:-2] if ds.endswith("_e") else ds
         # from https://github.com/THUDM/LongBench/blob/2e00731f8d0bff23dc4325161044d0ed8af94c1e/LongBench/eval.py#L52C25-L52C29
-        if df in ["trec", "triviaqa", "samsum", "lsht"] + [
-            "trec_e",
-            "triviaqa_e",
-            "samsum_e",
-            "lsht_e",
-        ]:
-            until = ["\n"]
-        else:
-            until = []
+        # Now we just set a boolean flag to indicate whether we need a newline
+        has_newline = df in ["trec", "triviaqa", "samsum", "lsht"]

         generation_kwargs = {
             "max_gen_toks": dataset2maxlen[df],
             "temperature": 1,
             "do_sample": True,
-            "until": until,
+            # We'll handle the until value directly in the template
         }
         raw_doc_to_text = (
             dataset2prompt[df]
             .replace("\n", "\\n")
@@ -196,25 +192,25 @@ if __name__ == "__main__":
         )
         metric_list = [
             {
-                "metric": f"!function metrics.{dataset2metric[df]}",
+                "metric": f'"{dataset2metric[df]}"',
                 "aggregation": "mean",
                 "higher_is_better": True,
             }
         ]
         data = {
-            "tag": [
-                "longbench_e" if ds.endswith("_e") else "longbench"
-            ],  # Now properly as a list
+            "tag": ["longbench_e" if ds.endswith("_e") else "longbench"],
             "task": f"longbench_{ds}",
             "dataset_path": "THUDM/LongBench",
             "test_split": "test",
             "dataset_name": ds,
             "doc_to_text": raw_doc_to_text,
-            "doc_to_target": "{{answers[0]}}",
+            "doc_to_target": "{{answers}}",
+            "process_results": f"!function metrics.get_{dataset2metric[df]}",
             "generation_kwargs": generation_kwargs,
+            "has_newline": has_newline,  # Add the flag to the template context
             "metric_list": metric_list,
-            "metadata": {"version": "2.0"},
+            "metadata": {"version": "3.0"},
         }
         # Render template
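The `until` field is now decided inside the Jinja template rather than passed through as data; a small check of how the two branches render (assumes `jinja2` is installed):

```python
from jinja2 import Template

snippet = Template('until: {% if has_newline %}["\\n"]{% else %}[]{% endif %}')

# With the flag set, the YAML gets a double-quoted "\n", which later parses to
# a real newline; without it, the stop list stays empty.
print(snippet.render(has_newline=True))   # until: ["\n"]
print(snippet.render(has_newline=False))  # until: []
```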
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: dureader
 doc_to_text: '请基于给定的文章回答下述问题。\n\n文章:{{context}}\n\n请基于上述文章回答下面的问题。\n\n问题:{{input}}\n回答:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_rouge_zh_score
 generation_kwargs:
   max_gen_toks: 128
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.rouge_zh_score
+  - metric: "rouge_zh_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_rouge_score
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.rouge_score
+  - metric: "rouge_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: gov_report_e
 doc_to_text: 'You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{{context}}\n\nNow, write a one-page summary of the report.\n\nSummary:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_rouge_score
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.rouge_score
+  - metric: "rouge_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_qa_f1_score
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.qa_f1_score
+  - metric: "qa_f1_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: hotpotqa_e
 doc_to_text: 'Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{{context}}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {{input}}\nAnswer:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_qa_f1_score
 generation_kwargs:
   max_gen_toks: 32
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.qa_f1_score
+  - metric: "qa_f1_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_code_sim_score
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.code_sim_score
+  - metric: "code_sim_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lcc_e
 doc_to_text: 'Please complete the code given below. \n{{context}}Next line of code:\n'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_code_sim_score
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.code_sim_score
+  - metric: "code_sim_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,16 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: lsht
 doc_to_text: '请判断给定新闻的类别,下面是一些例子。\n\n{{context}}\n{{input}}'
-doc_to_target: '{{answers[0]}}'
-process_results: !function metrics.classification_score
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_classification_score
 generation_kwargs:
   max_gen_toks: 64
   temperature: 1
   do_sample: True
-  until: ['\n']
+  until: ["\n"]
 metric_list:
   - metric: "classification_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -23,6 +23,7 @@
 import re
 import string
 from collections import Counter
+from typing import Union

 try:
     import jieba
@@ -33,7 +34,7 @@ except ImportError:
         'Please install the required dependencies for this task with `pip install lm_eval["longbench"] or `pip install jieba fuzzywuzzy rouge`'
     )

-# taken from https://github.com/THUDM/LongBench
+# taken and slightly modified from https://github.com/THUDM/LongBench
 def normalize_answer(s: str) -> str:
@@ -72,8 +73,7 @@ def normalize_zh_answer(s: str) -> str:
     return white_space_fix(remove_punc(lower(s)))


-def count_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def count_score(prediction: str, ground_truth: str, **kwargs):
     numbers = re.findall(r"\d+", prediction)
     right_num = 0
     for number in numbers:
@@ -83,8 +83,16 @@ def count_score(predictions: list[str], references: list[str], **kwargs) -> floa
     return float(final_score)


-def retrieval_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def get_count_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = count_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"count_score": output}
+
+
+def retrieval_score(prediction: str, ground_truth: str, **kwargs):
     pattern = r"Paragraph (\d+)"
     matches = re.findall(pattern, ground_truth)
     ground_truth_id = matches[0]
@@ -97,10 +105,16 @@ def retrieval_score(predictions: list[str], references: list[str], **kwargs) ->
     return float(final_score)


-def retrieval_zh_score(
-    predictions: list[str], references: list[str], **kwargs
-) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def get_retrieval_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = retrieval_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"retrieval_score": output}
+
+
+def retrieval_zh_score(prediction: str, ground_truth: str, **kwargs):
     pattern = r"段落(\d+)"
     matches = re.findall(pattern, ground_truth)
     ground_truth_id = matches[0]
@@ -113,8 +127,16 @@ def retrieval_zh_score(
     return float(final_score)


-def code_sim_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def get_retrieval_zh_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = retrieval_zh_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"retrieval_zh_score": output}
+
+
+def code_sim_score(prediction: str, ground_truth: str, **kwargs):
     all_lines = prediction.lstrip("\n").split("\n")
     prediction = ""
     for line in all_lines:
@@ -124,10 +146,18 @@ def code_sim_score(predictions: list[str], references: list[str], **kwargs) -> f
     return fuzz.ratio(prediction, ground_truth) / 100


-def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
-    prediction, ground_truth = results[0], doc["answers"][0]
+def get_code_sim_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0]  ## important! do not strip the prediction!
+    for ground_truth in doc["answers"]:
+        score = code_sim_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"code_sim_score": output}
+
+
+def classification_score(prediction: str, ground_truth: str, **kwargs):
     em_match_list = []
-    all_classes = doc["all_classes"]
+    all_classes = kwargs["all_classes"]
     for class_name in all_classes:
         if class_name in prediction:
             em_match_list.append(class_name)
@@ -138,35 +168,58 @@ def classification_score(doc: dict, results: list[str], **kwargs) -> dict:
         score = 1.0 / len(em_match_list)
     else:
         score = 0.0
-    return {"classification_score": score}
+    return score
+
+
+def get_classification_score(doc: dict, results: list[str]) -> dict:
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = classification_score(
+            prediction, ground_truth, all_classes=doc["all_classes"]
+        )
+        output = max(score, output)
+    return {"classification_score": output}


-def rouge_score(predictions: list[str], references: list[str], **kwargs) -> float:
+def rouge_score(predictions: str, ground_truth: str, **kwargs) -> float:
     global rouge
     if "rouge" not in globals():
         rouge = Rouge()
-    prediction, ground_truth = predictions[0], references[0]
     try:
-        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
+        scores = rouge.get_scores([predictions], [ground_truth], avg=True)
     # ruff: noqa
     except:
         return 0.0
     return scores["rouge-l"]["f"]


-def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def get_rouge_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = rouge_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"rouge_score": output}
+
+
+def rouge_zh_score(prediction: str, ground_truth: str, **kwargs):
     prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
     ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
-    score = rouge_score([prediction], [ground_truth])
+    score = rouge_score(prediction, ground_truth)
     return score


-def f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    try:
-        prediction, ground_truth = predictions[0], references[0]
-    except:
-        return 0.0
+def get_rouge_zh_score(doc, results, **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = rouge_zh_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"rouge_zh_score": output}
+
+
+def f1_score(prediction: Union[str, list], ground_truth: Union[str, list], **kwargs):
     common = Counter(prediction) & Counter(ground_truth)
     num_same = sum(common.values())
     if num_same == 0:
@@ -177,22 +230,25 @@ def f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
     return f1


-def qa_f1_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def get_f1_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = f1_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"f1_score": output}
+
+
+def qa_f1_score(prediction: str, ground_truth: str, **kwargs):
     normalized_prediction = normalize_answer(prediction)
     normalized_ground_truth = normalize_answer(ground_truth)

     prediction_tokens = normalized_prediction.split()
     ground_truth_tokens = normalized_ground_truth.split()
-    try:
-        res = f1_score(prediction_tokens, ground_truth_tokens)
-    except:
-        return 0.0
-    return res
+    return f1_score(prediction_tokens, ground_truth_tokens)


-def qa_f1_zh_score(predictions: list[str], references: list[str], **kwargs) -> float:
-    prediction, ground_truth = predictions[0], references[0]
+def qa_f1_zh_score(prediction: str, ground_truth: str, **kwargs):
     prediction_tokens = list(jieba.cut(prediction, cut_all=False))
     ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
     prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
@@ -200,3 +256,21 @@ def qa_f1_zh_score(predictions: list[str], references: list[str], **kwargs) -> f
     prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
     ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
     return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+def get_qa_f1_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = qa_f1_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"qa_f1_score": output}
+
+
+def get_qa_f1_zh_score(doc: dict, results: list[str], **kwargs):
+    output = 0.0
+    prediction = results[0].strip()
+    for ground_truth in doc["answers"]:
+        score = qa_f1_zh_score(prediction, ground_truth)
+        output = max(score, output)
+    return {"qa_f1_zh_score": output}
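As a sanity check of the new wrappers, a hedged example of calling one `process_results` hook directly (the import is illustrative; it assumes the metrics module shown above is on the path, e.g. when run from the task folder):

```python
import metrics  # the metrics module above, assumed importable

# Most get_* wrappers strip the prediction (code_sim deliberately does not),
# score it against every reference answer in doc["answers"], and keep the best
# score under the metric's name.
doc = {"answers": ["Paris", "the French capital"]}
results = ["  Paris\n"]

print(metrics.get_qa_f1_score(doc, results))  # {'qa_f1_score': 1.0}
```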
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multi_news
 doc_to_text: 'You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{{context}}\n\nNow, write a one-page summary of all the news.\n\nSummary:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_rouge_score
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.rouge_score
+  - metric: "rouge_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0
@@ -6,15 +6,16 @@ dataset_path: THUDM/LongBench
 test_split: test
 dataset_name: multi_news_e
 doc_to_text: 'You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{{context}}\n\nNow, write a one-page summary of all the news.\n\nSummary:'
-doc_to_target: '{{answers[0]}}'
+doc_to_target: '{{answers}}'
+process_results: !function metrics.get_rouge_score
 generation_kwargs:
   max_gen_toks: 512
   temperature: 1
   do_sample: True
   until: []
 metric_list:
-  - metric: !function metrics.rouge_score
+  - metric: "rouge_score"
     aggregation: mean
     higher_is_better: True
 metadata:
-  version: 2.0
+  version: 3.0