# SPDX-License-Identifier: Apache-2.0
"""RULER evaluation using vLLM ``LLM.generate`` + integrated kvprune (compactor) compression.

Run from the **repository root** (or any cwd if ``vllm`` is installed), e.g.::

    python tests/kvprune/evaluate/eval_ruler.py \\
        --dataset-parquet tests/kvprune/evaluate/test-00000-of-00001.parquet \\
        --dataset-split train \\
        --model Qwen/Qwen3-8B \\
        --compression-method compactor \\
        --seq-compression-ratio 0.5

Set ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` (``fa_triton`` | ``pdtriton`` | ``pdfa``) **before**
starting Python if you need a specific attention schedule (also supported via
``--attention-schedule``).
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
from datasets import load_dataset

# Local metrics (same directory as this script)
_SCRIPT_DIR = Path(__file__).resolve().parent
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))

from ruler_metrics import score_function  # noqa: E402

from vllm import LLM, SamplingParams  # noqa: E402
from vllm.kvprune.integration.compression_params import CompressionParams  # noqa: E402

logger = logging.getLogger(__name__)


def _hf_tokenizer(llm: LLM):
    """Return the tokenizer of ``llm``, unwrapped when it exposes a ``tokenizer`` attribute.

    ``llm.get_tokenizer()`` may hand back a wrapper object; if that object has a
    ``tokenizer`` attribute it is returned instead, otherwise the object itself.
    """
    tok = llm.get_tokenizer()
    return getattr(tok, "tokenizer", tok)


def messages_to_prompts(
    llm: LLM,
    messages: list[list[dict]],
    *,
    add_generation_prompt: bool,
    continue_final_message: bool,
    enable_thinking: bool,
) -> list[str]:
    """Render chat messages to a single prompt string per conversation (HF template).

    Args:
        llm: Engine whose tokenizer provides ``apply_chat_template``.
        messages: One list of ``{"role", "content"}`` dicts per conversation.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        continue_final_message: Forwarded to ``apply_chat_template``.
        enable_thinking: When true, ``enable_thinking=True`` is passed to the
            template; when false the keyword is omitted entirely.

    Returns:
        The rendered (untokenized) prompt for each conversation, in input order.
    """
    inner = _hf_tokenizer(llm)
    # The keyword is only forwarded when requested, so a template that knows
    # about ``enable_thinking`` falls back to its own default otherwise.
    template_kwargs: dict = {"enable_thinking": True} if enable_thinking else {}
    return [
        inner.apply_chat_template(
            conv,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            **template_kwargs,
        )
        for conv in messages
    ]


def parse_args() -> argparse.Namespace:
    """Build the command-line parser for the evaluation script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="RULER evaluation with vLLM kvprune (integrated compactor) compression."
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level.",
    )
    parser.add_argument(
        "--dataset-length",
        type=str,
        default="4096",
        help=(
            "Dataset configuration name for simonjegou/ruler on the HF hub. With "
            "--dataset-parquet it is used only for metadata / output filenames."
        ),
    )
    parser.add_argument(
        "--dataset-parquet",
        type=str,
        default=None,
        help=(
            "Local Parquet path (single file or glob). If set, loads via datasets parquet "
            "instead of simonjegou/ruler."
        ),
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="test",
        help="Split name (local parquet often uses 'train').",
    )
    parser.add_argument("--seed", type=int, default=42, help="Shuffle seed.")
    parser.add_argument(
        "--fraction",
        type=float,
        default=1.0,
        help="Fraction of dataset to use in (0, 1].",
    )
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="HF model id or path.")
    parser.add_argument("--max-num-seqs", type=int, default=32, help="vLLM max_num_seqs.")
    parser.add_argument(
        "--gpu-memory-utilization", type=float, default=0.95, help="GPU memory fraction."
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help=(
            "vLLM tensor parallel size. Default 1 uses the in-process shared-weight "
            "compactor on one GPU. For multi-GPU (e.g. 4), set this to the number of "
            "GPUs; compression then uses the TP collective_rpc path on workers."
        ),
    )
    parser.add_argument("--max-model-len", type=int, default=40960, help="max_model_len.")
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="vLLM enforce_eager (on by default when --kvprune-compression).",
    )
    parser.add_argument(
        "--kvprune-compression",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable kvprune_compression on LLM (skip v1 CUDA graphs, minimal v1 KV blocks). "
        "Default: True.",
    )
    parser.add_argument(
        "--attention-schedule",
        type=str,
        default=None,
        help=(
            "If set, assigns VLLM_KVPRUNE_ATTENTION_SCHEDULE before engine init, e.g. "
            "fa_triton, pdtriton, pdfa (see vllm/kvprune/integration/config_adapter.py)."
        ),
    )
    parser.add_argument("--max-tokens", type=int, default=256, help="max_tokens (generation).")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
    parser.add_argument(
        "--compression-method",
        type=str,
        default="compactor",
        choices=["compactor", "criticaladakv", "snapkv"],
        help="kvprune compression method alias.",
    )
    parser.add_argument(
        "--seq-compression-ratio",
        type=float,
        default=0.5,
        help="Per-sequence compression ratio in (0, 1].",
    )
    parser.add_argument(
        "--protected-first-tokens",
        type=int,
        default=8,
        help="Protected prefix token count for pruning.",
    )
    parser.add_argument(
        "--extra-protected-last-tokens",
        type=int,
        default=16,
        help="Added to tokenized(answer_prefix+question) length for protected_last_tokens.",
    )
    parser.add_argument(
        "--tokenizer-add-generation-prompt",
        action="store_true",
        help="apply_chat_template add_generation_prompt=True.",
    )
    parser.add_argument(
        "--tokenizer-enable-thinking",
        action="store_true",
        help="apply_chat_template enable_thinking=True (Qwen3).",
    )
    parser.add_argument(
        "--no-tokenizer-continue-final-message",
        dest="tokenizer_continue_final_message",
        action="store_false",
        help="continue_final_message=False (default True).",
    )
    parser.set_defaults(tokenizer_continue_final_message=True)
    parser.add_argument(
        "--results-dir",
        type=str,
        default="results",
        help="Directory for JSON summary and JSONL details.",
    )
    return parser.parse_args()


def _validate_args(args: argparse.Namespace) -> None:
    """Reject argument combinations that would only fail after expensive setup.

    Raises:
        ValueError: If ``--fraction`` is outside ``(0, 1]``, or if both
            ``add_generation_prompt`` and ``continue_final_message`` are enabled.
    """
    if not (0 < args.fraction <= 1.0):
        raise ValueError("--fraction must be in (0, 1].")
    # HF ``apply_chat_template`` refuses this combination; catching it here
    # avoids loading the dataset and the model just to hit that error.
    if args.tokenizer_add_generation_prompt and args.tokenizer_continue_final_message:
        raise ValueError(
            "--tokenizer-add-generation-prompt cannot be combined with "
            "continue_final_message=True; also pass --no-tokenizer-continue-final-message."
        )


def _load_examples(args: argparse.Namespace):
    """Load the evaluation split, then shuffle and subsample it as requested.

    A local Parquet file is used when ``--dataset-parquet`` is given, otherwise the
    ``simonjegou/ruler`` hub dataset with ``--dataset-length`` as its configuration.
    """
    if args.dataset_parquet:
        logger.info(
            "Loading local parquet from %s (split=%s)",
            args.dataset_parquet,
            args.dataset_split,
        )
        dataset = load_dataset(
            "parquet",
            data_files=args.dataset_parquet,
            split=args.dataset_split,
        )
    else:
        logger.info(
            "Loading simonjegou/ruler length=%s split=%s",
            args.dataset_length,
            args.dataset_split,
        )
        dataset = load_dataset(
            "simonjegou/ruler",
            args.dataset_length,
            split=args.dataset_split,
        )

    # A negative seed leaves the dataset in its stored order.
    if args.seed is not None and args.seed >= 0:
        dataset = dataset.shuffle(seed=args.seed)

    if args.fraction < 1.0:
        # Keep at least one example, but never ask for more rows than exist
        # (an empty dataset must stay empty instead of failing on index 0).
        n_examples = min(len(dataset), max(1, int(len(dataset) * args.fraction)))
        dataset = dataset.select(range(n_examples))
    return dataset


def _score_responses(
    dataset,
    responses: list[str],
    end_protected_lengths: list[int],
    args: argparse.Namespace,
) -> tuple[dict, list]:
    """Score every response and collect per-task scores plus per-example records.

    Returns:
        ``(scores_by_task, per_example)`` where ``scores_by_task`` maps a task name
        to the list of its scores (in dataset order) and ``per_example`` holds one
        JSON-serializable record per example.
    """
    scores_by_task: dict = {}
    per_example: list = []
    for idx, (example, response) in enumerate(zip(dataset, responses)):
        task = example["task"]
        answer = example["answer"]
        score = score_function(
            generated=response,
            ground_truth=answer,
            task_category=task,
        )
        scores_by_task.setdefault(task, []).append(score)
        per_example.append(
            {
                "index": idx,
                "task": task,
                "context": example["context"],
                "question": example["question"],
                "answer_prefix": example["answer_prefix"],
                "ground_truth": answer,
                "generated": response,
                "score": score,
                "compression_params": {
                    "seq_compression_ratio": args.seq_compression_ratio,
                    "compression_method": args.compression_method,
                    "protected_first_tokens": args.protected_first_tokens,
                    "protected_last_tokens": end_protected_lengths[idx],
                },
            }
        )
    return scores_by_task, per_example


def _write_results(
    args: argparse.Namespace,
    num_examples: int,
    overall_avg: float,
    per_task_summary: dict,
    per_example: list,
) -> tuple[str, str]:
    """Write the JSON summary and JSONL details files; return their paths."""
    os.makedirs(args.results_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_model_name = args.model.replace("/", "_")
    base_name = f"ruler_{args.dataset_length}_{safe_model_name}_{timestamp}"
    summary_path = os.path.join(args.results_dir, base_name + "_summary.json")
    details_path = os.path.join(args.results_dir, base_name + "_details.jsonl")

    ds_name = args.dataset_parquet or "simonjegou/ruler"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "timestamp": timestamp,
                "model": args.model,
                "dataset": ds_name,
                "dataset_length": args.dataset_length,
                "num_examples": num_examples,
                "overall_avg_score": overall_avg,
                "per_task": per_task_summary,
                "arguments": vars(args),
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    with open(details_path, "w", encoding="utf-8") as f:
        for row in per_example:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return summary_path, details_path


def main() -> None:
    """Run the full evaluation: load data, generate with compression, score, save.

    Raises:
        ValueError: For invalid argument values or combinations (see ``_validate_args``).
        RuntimeError: If generation returns a different number of outputs than prompts.
    """
    args = parse_args()
    if args.attention_schedule:
        os.environ["VLLM_KVPRUNE_ATTENTION_SCHEDULE"] = args.attention_schedule.strip()

    torch.manual_seed(args.seed)

    logging.basicConfig(
        level=getattr(logging, args.log_level.upper(), logging.INFO),
        format="%(asctime)s %(levelname)s: %(message)s",
    )

    # Fail fast, before any dataset download or model load.
    _validate_args(args)

    dataset = _load_examples(args)
    logger.info("Examples: %d", len(dataset))

    # The answer prefix is supplied as a partial assistant turn.
    messages = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": example["context"] + " " + example["question"]},
            {"role": "assistant", "content": example["answer_prefix"]},
        ]
        for example in dataset
    ]

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=args.max_model_len,
        max_num_seqs=args.max_num_seqs,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enforce_eager=args.enforce_eager or args.kvprune_compression,
        kvprune_compression=args.kvprune_compression,
    )

    # Per-example protected tail: token length of answer_prefix+question plus a
    # fixed margin from --extra-protected-last-tokens.
    tok = _hf_tokenizer(llm)
    end_protected_lengths = [
        args.extra_protected_last_tokens
        + len(
            tok.encode(
                example["answer_prefix"] + example["question"],
                add_special_tokens=False,
            )
        )
        for example in dataset
    ]

    prompts = messages_to_prompts(
        llm,
        messages,
        add_generation_prompt=args.tokenizer_add_generation_prompt,
        continue_final_message=args.tokenizer_continue_final_message,
        enable_thinking=args.tokenizer_enable_thinking,
    )

    sampling_params = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
    )

    # One CompressionParams per prompt; only the protected tail length varies.
    compression_list = [
        CompressionParams(
            compression_ratio=args.seq_compression_ratio,
            compression_method=args.compression_method,
            protected_first_tokens=args.protected_first_tokens,
            protected_last_tokens=protected_last,
        )
        for protected_last in end_protected_lengths
    ]

    logger.info("Running LLM.generate with kvprune compression on %d prompts.", len(prompts))
    outputs = llm.generate(
        prompts,
        sampling_params,
        compression=compression_list,
    )
    # Scoring pairs outputs with examples positionally; a count mismatch would
    # otherwise be silently truncated by zip().
    if len(outputs) != len(prompts):
        raise RuntimeError(
            f"LLM.generate returned {len(outputs)} outputs for {len(prompts)} prompts."
        )
    responses = [o.outputs[0].text.strip() for o in outputs]

    logger.info("Scoring responses.")
    results, per_example = _score_responses(dataset, responses, end_protected_lengths, args)

    all_sum, all_count = 0.0, 0
    for row in per_example:
        all_sum += row["score"]
        all_count += 1

    per_task_summary = {}
    for task, scores in results.items():
        avg = sum(scores) / len(scores)
        print(task, f"{avg:.3f}")
        per_task_summary[task] = {
            "avg_score": avg,
            "num_examples": len(scores),
            "sum_scores": sum(scores),
        }
    overall_avg = all_sum / all_count if all_count > 0 else 0.0
    print(f"ALL: {overall_avg:.3f}")

    summary_path, details_path = _write_results(
        args, len(dataset), overall_avg, per_task_summary, per_example
    )
    logger.info("Wrote %s and %s", summary_path, details_path)


if __name__ == "__main__":
    main()