# eval_longbench.py -- committed by chenzk
# SPDX-License-Identifier: Apache-2.0
"""LongBench evaluation via vLLM ``LLM.generate`` + kvprune compression (same folder layout as RULER)."""
from __future__ import annotations

import json
import logging
import os
import sys
from pathlib import Path

from datasets import concatenate_datasets, load_dataset

_SCRIPT_DIR = Path(__file__).resolve().parent
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))
from longbench_metrics import dataset2metric  # noqa: E402

from vllm import LLM, SamplingParams  # noqa: E402
from vllm.kvprune.integration.compression_params import CompressionParams  # noqa: E402


def _hf_tokenizer(llm: LLM):
    """Return the tokenizer behind ``llm``, unwrapped when it is a wrapper.

    The object returned by ``llm.get_tokenizer()`` is inspected for a
    ``tokenizer`` attribute. When that attribute exists its value is
    returned; otherwise the object from ``get_tokenizer()`` is returned
    unchanged.
    """
    wrapper = llm.get_tokenizer()
    try:
        return wrapper.tokenizer
    except AttributeError:
        # No inner tokenizer exposed: the object itself is what callers use.
        return wrapper


def messages_to_prompts(
    llm: LLM,
    messages: list[list[dict]],
    *,
    add_generation_prompt: bool,
    enable_thinking: bool,
) -> list[str]:
    """Render each conversation to a prompt string via the chat template.

    Args:
        llm: Engine whose tokenizer supplies ``apply_chat_template``.
        messages: Conversations, each a list of message dicts.
        add_generation_prompt: Forwarded unchanged to ``apply_chat_template``.
        enable_thinking: When true, ``enable_thinking=True`` is forwarded to
            the template. When false the keyword is omitted entirely (it is
            not sent as ``False``), so the template's own default applies.

    Returns:
        One rendered, untokenized prompt per conversation, in input order.
    """
    tokenizer = _hf_tokenizer(llm)
    # Opt-in only: an absent keyword leaves the template default in effect.
    template_kwargs: dict = {"enable_thinking": True} if enable_thinking else {}
    return [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **template_kwargs,
        )
        for conversation in messages
    ]


def main() -> None:
    """Evaluate LongBench with kvprune compression and print per-task scores.

    Steps, in order:

    1. Load the prompt templates and per-task generation budgets from
       ``longbench_config/`` next to this script.
    2. Load the listed LongBench subsets, concatenate them, and shuffle with
       a fixed seed.
    3. Render every example through the chat template and generate with one
       ``SamplingParams`` and one ``CompressionParams`` per prompt.
    4. Score each prediction as the best metric value over its reference
       answers, then print the mean per task and the mean over all examples.

    Environment variables:
        KVPRUNE_EVAL_MODEL: Model name or path
            (default ``meta-llama/Llama-3.1-8B-Instruct``).
        KVPRUNE_EVAL_TP: Tensor-parallel size (default ``2``).
        KVPRUNE_SEQ_COMPRESSION_RATIO: ``compression_ratio`` passed to every
            ``CompressionParams`` (default ``0.25``).
    """
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s"
    )
    cfg_dir = _SCRIPT_DIR / "longbench_config"
    # Context managers close the config files promptly instead of leaving the
    # open handles to the garbage collector.
    with open(cfg_dir / "dataset2prompt.json", "r", encoding="utf-8") as fh:
        prompts = json.load(fh)
    with open(cfg_dir / "dataset2maxlen.json", "r", encoding="utf-8") as fh:
        max_gen_lens = json.load(fh)

    subset_names = [
        "narrativeqa",
        "qasper",
        "multifieldqa_en",
        "hotpotqa",
        "2wikimqa",
        "musique",
        "gov_report",
        "qmsum",
        "multi_news",
        "trec",
        "triviaqa",
        "samsum",
        "passage_retrieval_en",
        "passage_count",
        "lcc",
        "repobench-p",
    ]
    dataset = concatenate_datasets(
        [
            load_dataset(
                "THUDM/LongBench", subset, split="test", trust_remote_code=True
            )
            for subset in subset_names
        ]
    ).shuffle(seed=42)

    # Rows tagged "<task>_e" are mapped onto the base task name so they share
    # that task's prompt template, length budget and metric.
    dset_names = [
        name[:-2] if name.endswith("_e") else name
        for name in (item["dataset"] for item in dataset)
    ]
    gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names]

    messages = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompts[dset_name].format(**item)},
        ]
        for dset_name, item in zip(dset_names, dataset)
    ]

    model = os.environ.get("KVPRUNE_EVAL_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
    tp = int(os.environ.get("KVPRUNE_EVAL_TP", "2"))
    seq_ratio = float(os.environ.get("KVPRUNE_SEQ_COMPRESSION_RATIO", "0.25"))

    llm = LLM(
        model=model,
        max_num_seqs=64,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=tp,
        max_model_len=128000,
        kvprune_compression=True,
    )
    text_prompts = messages_to_prompts(
        llm,
        messages,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # NOTE(review): near-zero temperature kept as-is; confirm whether exact
    # greedy decoding (temperature=0) was intended.
    sampling_params = [
        SamplingParams(max_tokens=g, temperature=0.00001) for g in gen_lengths
    ]
    # One instance per prompt: ``[params] * n`` would hand the very same
    # object to every request, so any per-request mutation would be shared.
    compression = [
        CompressionParams(
            compression_ratio=seq_ratio,
            compression_method="compactor",
            protected_first_tokens=8,
            protected_last_tokens=64,
        )
        for _ in text_prompts
    ]

    outputs = llm.generate(text_prompts, sampling_params, compression=compression)
    responses = [o.outputs[0].text for o in outputs]

    # Tasks whose prediction is cut down to its first non-empty-leading line
    # before scoring.
    first_line_tasks = ("trec", "triviaqa", "samsum", "lsht")

    results: dict = {}
    for dset_name, prediction, item in zip(dset_names, responses, dataset):
        task_scores = results.setdefault(dset_name, [])
        pred = prediction
        if dset_name in first_line_tasks:
            pred = pred.lstrip("\n").split("\n")[0]
        # Best score over all reference answers, floored at 0.0.
        score = 0.0
        for ground_truth in item["answers"]:
            score = max(
                score,
                dataset2metric[dset_name](
                    pred, ground_truth, all_classes=item["all_classes"]
                ),
            )
        task_scores.append(score)

    all_sum, all_count = 0, 0
    for task, scores in results.items():
        task_sum = sum(scores)
        print(task, f"{task_sum / len(scores):.2f}")
        all_sum += task_sum
        all_count += len(scores)
    if all_count:
        print(f"ALL: {all_sum / all_count:.2f}")
    else:
        # An empty evaluation set would otherwise end in ZeroDivisionError.
        logging.warning("No examples were scored; skipping the overall average.")


if __name__ == "__main__":
    main()