"docs/vscode:/vscode.git/clone" did not exist on "0d2d424fbef933e4b81bea20a660ee6fc8b75ab0"
Unverified Commit 0aa65f94 authored by Binyao Jiang, committed by GitHub

[Fix] Improve longbench prompt and other logics (#11474)

parent 0ecb4261
@@ -103,6 +103,7 @@ def run_eval(args):
         categories = args.categories.split(",") if args.categories else None
         eval_obj = LongBenchV2Eval(
+            model=args.model,
             data_source=data_source,
             num_examples=args.num_examples,
             num_threads=args.num_threads,
...
@@ -290,6 +290,9 @@ def aggregate_results(
     htmls = []
     convos = []
     for single_eval_result in single_eval_results:
+        # Skip None results
+        if single_eval_result is None:
+            continue
         for name, value in single_eval_result.metrics.items():
             name2values[name].append(value)
         if single_eval_result.score is not None:
...
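The new guard matters because, with the per-example context-length check added further down in this diff, an evaluation thread can now return `None` instead of a result object, and the aggregation loop would otherwise fail on `single_eval_result.metrics`. A minimal sketch of the idea, using a simplified stand-in type rather than sglang's actual `SingleEvalResult`:

```python
# Sketch of None-tolerant aggregation; FakeResult is a simplified stand-in
# for sglang's SingleEvalResult, not the real class.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeResult:
    metrics: Dict[str, float] = field(default_factory=dict)
    score: Optional[float] = None


def aggregate(results: List[Optional[FakeResult]]) -> Dict[str, float]:
    name2values = defaultdict(list)
    for r in results:
        if r is None:
            # Skipped examples (e.g. filtered out by context length) produce None.
            continue
        for name, value in r.metrics.items():
            name2values[name].append(value)
        if r.score is not None:
            name2values["score"].append(r.score)
    return {k: sum(v) / len(v) for k, v in name2values.items() if v}


print(aggregate([FakeResult({"acc": 1.0}, 1.0), None, FakeResult({"acc": 0.0}, 0.0)]))
# -> {'acc': 0.5, 'score': 0.5}
```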
@@ -12,6 +12,8 @@ import os
 import re
 from typing import Any, Dict, List, Optional

+from transformers import AutoTokenizer
+
 from sglang.test import simple_eval_common as common
 from sglang.test.simple_eval_common import (
     ANSWER_PATTERN_MULTICHOICE,
@@ -55,7 +57,11 @@ def format_longbench_v2_question(row: dict) -> str:
     choice_D = row.get("D", row.get("choice_D", ""))

     # Official LongBench-v2 template
-    prompt = f"""{context.strip()}
+    prompt = f"""
+Please read the following text and answer the question below.
+<text>
+{context.strip()}
+</text>

 What is the correct answer to this question: {question.strip()}

 Choices:
@@ -64,7 +70,7 @@ Choices:
 (C) {choice_C.strip()}
 (D) {choice_D.strip()}

-The correct answer is"""
+Format your response as follows: "The correct answer is (insert answer here)"."""

     return prompt
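For reference, this is roughly what `format_longbench_v2_question` produces after the change: the context is wrapped in `<text>` tags behind an explicit reading instruction, and the closing line now tells the model exactly how to format its answer. An approximate reconstruction (field names follow the `row.get(...)` calls visible in the diff; exact whitespace may differ):

```python
def format_longbench_v2_question_sketch(row: dict) -> str:
    """Approximate reconstruction of the updated LongBench-v2 prompt template."""
    context = row.get("context", "")
    question = row.get("question", "")
    choice_a = row.get("A", row.get("choice_A", ""))
    choice_b = row.get("B", row.get("choice_B", ""))
    choice_c = row.get("C", row.get("choice_C", ""))
    choice_d = row.get("D", row.get("choice_D", ""))

    return f"""
Please read the following text and answer the question below.
<text>
{context.strip()}
</text>

What is the correct answer to this question: {question.strip()}

Choices:
(A) {choice_a.strip()}
(B) {choice_b.strip()}
(C) {choice_c.strip()}
(D) {choice_d.strip()}

Format your response as follows: "The correct answer is (insert answer here)"."""
```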
@@ -106,6 +112,7 @@ class LongBenchV2Eval(Eval):
     def __init__(
         self,
+        model: str = None,
         data_source: str = DEFAULT_DATASET,
         num_examples: Optional[int] = None,
         num_threads: int = 1,
@@ -126,6 +133,9 @@ class LongBenchV2Eval(Eval):
             max_context_length: Maximum context length in characters
             min_context_length: Minimum context length in characters
         """
+        self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+        self.min_context_length = min_context_length
+        self.max_context_length = max_context_length
         # Load dataset based on data source type
         examples = self._load_dataset(data_source)
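The constructor now loads a Hugging Face tokenizer for the evaluated model so that the context-length bounds can be enforced in tokens rather than characters. A small standalone sketch of that usage (the model name here is only an example; in the eval it comes from the `model` argument / `--model` flag):

```python
from transformers import AutoTokenizer

# Example model name; substitute the model actually being served.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True
)

prompt = "Please read the following text and answer the question below."
num_tokens = len(tokenizer.encode(prompt))
print(f"{num_tokens} tokens")
```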
@@ -133,11 +143,6 @@ class LongBenchV2Eval(Eval):
         if categories:
             examples = [ex for ex in examples if ex.get("category") in categories]

-        if min_context_length or max_context_length:
-            examples = self._filter_by_context_length(
-                examples, min_context_length, max_context_length
-            )
-
         # Sample examples if specified
         if num_examples:
             assert n_repeats == 1, "n_repeats only supported when not sampling examples"
@@ -246,26 +251,23 @@ class LongBenchV2Eval(Eval):
         return normalized

-    def _filter_by_context_length(
+    def _check_context_length(
         self,
-        examples: List[Dict[str, Any]],
+        formatted_question: str,
+        tokenizer: AutoTokenizer,
         min_length: Optional[int],
         max_length: Optional[int],
-    ) -> List[Dict[str, Any]]:
+    ) -> bool:
         """Filter examples by context length measured in characters."""
-        filtered = []
-        for example in examples:
-            context = example.get("context", "")
-            context_length = len(context)
-            if min_length is not None and context_length < min_length:
-                continue
-            if max_length is not None and context_length > max_length:
-                continue
-            filtered.append(example)
-
-        return filtered
+        input_ids = tokenizer.encode(formatted_question)
+        context_length = len(input_ids)
+
+        if min_length is not None and context_length < min_length:
+            return False
+        if max_length is not None and context_length > max_length:
+            return False
+
+        return True

     def __call__(self, sampler: SamplerBase) -> EvalResult:
         """Run the evaluation."""
@@ -274,6 +276,16 @@ class LongBenchV2Eval(Eval):
             # Format the question using official template
             formatted_question = format_longbench_v2_question(row)

+            if self.min_context_length or self.max_context_length:
+                if not self._check_context_length(
+                    formatted_question,
+                    self.tokenizer,
+                    self.min_context_length,
+                    self.max_context_length,
+                ):
+                    # Skip this example
+                    return None
+
             prompt_messages = [
                 sampler._pack_message(content=formatted_question, role="user")
             ]
...
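Put together, the new flow is: `run_eval` passes `--model` through to `LongBenchV2Eval`, the constructor loads the matching tokenizer, each example whose formatted prompt falls outside the token bounds is skipped by returning `None`, and `aggregate_results` now tolerates those `None` entries. A hedged usage sketch; the import path, model name, and unlisted constructor defaults are assumptions, not copied from the repo:

```python
# Hypothetical usage sketch based on the arguments visible in this diff.
from sglang.test.simple_eval_longbench_v2 import LongBenchV2Eval  # assumed module path

eval_obj = LongBenchV2Eval(
    model="meta-llama/Llama-3.1-8B-Instruct",  # used to load the tokenizer
    num_examples=100,
    num_threads=8,
    min_context_length=8_000,    # bounds are measured in tokens of the formatted prompt
    max_context_length=128_000,  # out-of-range examples are skipped (return None)
)
# result = eval_obj(sampler)  # sampler: any SamplerBase implementation
```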