# SPDX-License-Identifier: Apache-2.0 """LongBench evaluation via vLLM ``LLM.generate`` + kvprune compression (same folder layout as RULER).""" from __future__ import annotations import json import logging import os import sys from pathlib import Path from datasets import concatenate_datasets, load_dataset _SCRIPT_DIR = Path(__file__).resolve().parent if str(_SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(_SCRIPT_DIR)) from longbench_metrics import dataset2metric # noqa: E402 from vllm import LLM, SamplingParams # noqa: E402 from vllm.kvprune.integration.compression_params import CompressionParams # noqa: E402 def _hf_tokenizer(llm: LLM): tok = llm.get_tokenizer() return getattr(tok, "tokenizer", tok) def messages_to_prompts( llm: LLM, messages: list[list[dict]], *, add_generation_prompt: bool, enable_thinking: bool, ) -> list[str]: inner = _hf_tokenizer(llm) out: list[str] = [] kw: dict = {} if enable_thinking: kw["enable_thinking"] = True for conv in messages: text = inner.apply_chat_template( conv, tokenize=False, add_generation_prompt=add_generation_prompt, **kw, ) out.append(text) return out if __name__ == "__main__": logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s" ) cfg_dir = _SCRIPT_DIR / "longbench_config" prompts = json.load(open(cfg_dir / "dataset2prompt.json", "r", encoding="utf-8")) max_gen_lens = json.load(open(cfg_dir / "dataset2maxlen.json", "r", encoding="utf-8")) datasets = [ "narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique", "gov_report", "qmsum", "multi_news", "trec", "triviaqa", "samsum", "passage_retrieval_en", "passage_count", "lcc", "repobench-p", ] dataset = concatenate_datasets( [ load_dataset("THUDM/LongBench", n, split="test", trust_remote_code=True) for n in datasets ] ).shuffle(seed=42) dset_names = [ item["dataset"] if item["dataset"][-2:] != "_e" else item["dataset"][:-2] for item in dataset ] gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names] messages = [ [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompts[dset_name].format(**item)}, ] for dset_name, item in zip(dset_names, dataset) ] model = os.environ.get("KVPRUNE_EVAL_MODEL", "meta-llama/Llama-3.1-8B-Instruct") tp = int(os.environ.get("KVPRUNE_EVAL_TP", "2")) seq_ratio = float(os.environ.get("KVPRUNE_SEQ_COMPRESSION_RATIO", "0.25")) llm = LLM( model=model, max_num_seqs=64, gpu_memory_utilization=0.95, tensor_parallel_size=tp, max_model_len=128000, kvprune_compression=True, ) text_prompts = messages_to_prompts( llm, messages, add_generation_prompt=True, enable_thinking=False, ) sampling_params = [ SamplingParams(max_tokens=g, temperature=0.00001) for g in gen_lengths ] n = len(text_prompts) compression = [ CompressionParams( compression_ratio=seq_ratio, compression_method="compactor", protected_first_tokens=8, protected_last_tokens=64, ) ] * n outputs = llm.generate(text_prompts, sampling_params, compression=compression) responses = [o.outputs[0].text for o in outputs] results: dict = {} for dset_name, prediction, item in zip(dset_names, responses, dataset): results.setdefault(dset_name, []) pred = prediction if dset_name in ["trec", "triviaqa", "samsum", "lsht"]: pred = pred.lstrip("\n").split("\n")[0] score = 0.0 for ground_truth in item["answers"]: score = max( score, dataset2metric[dset_name]( pred, ground_truth, all_classes=item["all_classes"] ), ) results[dset_name].append(score) all_sum, all_count = 0, 0 for task, scores in results.items(): avg = sum(scores) / len(scores) print(task, f"{avg:.2f}") all_sum += sum(scores) all_count += len(scores) print(f"ALL: {all_sum / all_count:.2f}")