# eval_longbench.py — LongBench evaluation script (committed by chenzk)
import json
import logging

from datasets import concatenate_datasets, load_dataset

from compactor_vllm import (
    LLM,
    LLMConfig,
    SamplingParams,
)
from compactor_vllm.compression import (
    BatchCompressionParams,
    CompressionMethod,
    SequenceCompressionParams,
)
from compactor_vllm.config.engine_config import AttentionBackend
from longbench_metrics import dataset2metric

if __name__ == "__main__":
    # Script entry point: load the LongBench test splits listed below, generate
    # answers with compactor_vllm using the COMPACTOR compression method, score
    # every prediction with the per-task metric from `longbench_metrics`, and
    # print the mean score per task followed by the mean over all examples.
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s"
    )
    datasets = [
        "narrativeqa",
        "qasper",
        "multifieldqa_en",
        "hotpotqa",
        "2wikimqa",
        "musique",
        "gov_report",
        "qmsum",
        "multi_news",
        "trec",
        "triviaqa",
        "samsum",
        "passage_retrieval_en",
        "passage_count",
        "lcc",
        "repobench-p",
    ]
    # All tasks are merged into one dataset and shuffled with a fixed seed so
    # the example order is reproducible across runs.
    dataset = concatenate_datasets(
        [
            load_dataset("THUDM/LongBench", n, split="test", trust_remote_code=True)
            for n in datasets
        ]
    ).shuffle(seed=42)

    # Uncomment to evaluate on only the first 200 shuffled examples.
    # dataset = dataset.take(200)

    # Context managers close the config files deterministically, and the
    # explicit encoding keeps decoding independent of the platform locale.
    with open("longbench_config/dataset2prompt.json", "r", encoding="utf-8") as f:
        prompts = json.load(f)
    with open("longbench_config/dataset2maxlen.json", "r", encoding="utf-8") as f:
        max_gen_lens = json.load(f)

    tokenizer_kwargs = {"add_generation_prompt": True, "enable_thinking": False}

    # Walk the dataset a single time and keep only what generation and scoring
    # need, instead of re-iterating the shuffled dataset for every stage.
    dset_names = []
    messages = []
    references = []  # one (answers, all_classes) pair per example
    for item in dataset:
        raw_name = item["dataset"]
        # A trailing "_e" is stripped so the name matches the keys used by the
        # prompt templates, the generation-length table and the metric table.
        dset_name = raw_name[:-2] if raw_name.endswith("_e") else raw_name
        dset_names.append(dset_name)
        messages.append(
            [
                {
                    "role": "system",
                    "content": "You are a helpful assistant.",
                },
                # The template is filled from the example's own fields.
                {"role": "user", "content": prompts[dset_name].format(**item)},
            ]
        )
        references.append((item["answers"], item["all_classes"]))
    gen_lengths = [max_gen_lens[dset_name] for dset_name in dset_names]

    # model = "Qwen/Qwen3-8B"
    model = "meta-llama/Llama-3.1-8B-Instruct"
    # model = "Qwen/Qwen3-30B-A3B-Instruct-2507"
    config = LLMConfig(
        model,
        max_num_seqs=64,
        gpu_memory_utilization=0.95,
        tensor_parallel_size=2,
        max_model_len=128000,
        attention_backend=AttentionBackend.COMPACTOR_TRITON,
        leverage_sketch_size=32,
    )
    llm = LLM(config)
    responses = llm.generate_chat(
        messages,
        # One SamplingParams per prompt because the generation budget is
        # task-specific. NOTE(review): the near-zero temperature presumably
        # approximates greedy decoding -- confirm against the engine.
        [SamplingParams(max_new_tokens=g, temperature=0.00001) for g in gen_lengths],
        BatchCompressionParams(
            compression_method=CompressionMethod.COMPACTOR,
            do_chunked_compression=False,
            chunk_size=4096,
        ),
        # The same SequenceCompressionParams instance is repeated for every
        # prompt. NOTE(review): 0.25 is presumably the compression ratio --
        # confirm against SequenceCompressionParams.
        per_sequence_compression_params=[
            SequenceCompressionParams(
                0.25, protected_first_tokens=8, protected_last_tokens=64
            )
        ]
        * len(messages),
        tokenizer_kwargs=tokenizer_kwargs,
        return_sequences=False,
    )

    # For these tasks only the first line of the generation is scored.
    first_line_tasks = ("trec", "triviaqa", "samsum", "lsht")
    results = {}
    for dset_name, prediction, (answers, all_classes) in zip(
        dset_names, responses, references
    ):
        task_scores = results.setdefault(dset_name, [])

        if dset_name in first_line_tasks:
            prediction = prediction.lstrip("\n").split("\n")[0]

        # An example's score is its best score over all reference answers,
        # with a floor of 0.0 when there are no references.
        score = 0.0
        for ground_truth in answers:
            score = max(
                score,
                dataset2metric[dset_name](
                    prediction, ground_truth, all_classes=all_classes
                ),
            )
        task_scores.append(score)

    all_sum, all_count = 0, 0
    for task, scores in results.items():
        # Every list in `results` has at least one score, so this division is
        # safe.
        this_task_sum = sum(scores)
        this_task_count = len(scores)
        print(task, f"{this_task_sum / this_task_count:.2f}")
        all_sum += this_task_sum
        all_count += this_task_count
    if all_count:
        print(f"ALL: {all_sum / all_count:.2f}")
    else:
        # Avoid a ZeroDivisionError when nothing was scored.
        logging.warning("No examples were scored; skipping the overall average.")