loogle_eval.py
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List

import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm
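
# Evaluate an OpenAI-compatible endpoint on the LooGLE long-dependency QA split.
# Phase 1 (benchmark) sends one request per example and caches each response as a
# pickle in --output-dir; phase 2 (analyse) loads the cached responses and reports
# the average BERTScore F1 against the reference answers.
#
# Example:
#   python loogle_eval.py --api-url http://127.0.0.1:30000/v1 \
#       --model meta-llama/Llama-4-Maverick-17B-128E-Instruct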


def get_client(api_url: str) -> openai.AsyncOpenAI:
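    # Local OpenAI-compatible servers typically ignore the API key, but the client
    # requires one to be set, so fall back to a dummy value.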
    if os.getenv("OPENAI_API_KEY") is None:
        os.environ["OPENAI_API_KEY"] = "EMPTY"
    return openai.AsyncOpenAI(base_url=api_url)


def get_dataset():
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")


async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
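    # Each response is cached to its own pickle so an interrupted run can resume
    # without re-sending completed requests.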
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return

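    # Build the prompt: the full document context followed by the question.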
    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

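    # The shared semaphore caps the number of in-flight requests at --max-concurrency.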
    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as e:
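            # Persist the failure (e.g. the context exceeds the model's window) so
            # analyse() can detect and skip this example.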
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

    with open(output_file, "wb") as f:
        pickle.dump(response, f)


async def benchmark(args):
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

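    # Launch one task per dataset example; concurrency is throttled inside
    # fetch_response by the shared semaphore.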
    tasks: List[asyncio.Task] = []
    for idx, ex in enumerate(dataset):
        tasks.append(
            asyncio.create_task(
                fetch_response(
                    client,
                    ex["context"],
                    ex["question"],
                    semaphore,
                    idx,
                    args.model,
                    output_dir,
                )
            )
        )

    for coro in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await coro


def analyse(args):
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

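    # Score generated answers against reference answers with BERTScore, on GPU when available.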
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

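        # Load the cached response; examples whose request failed were saved as an
        # error dict and are skipped.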
        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        if isinstance(response, dict) and "error" in response:
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

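    # Score in fixed-size batches to keep memory usage bounded.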
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend([float(x) for x in f1_scores])

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAI‑compatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    args = parser.parse_args()

    asyncio.run(benchmark(args))

    analyse(args)