gsm8k_eval.py 9.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""

import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator

import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm

# Sentinel returned by get_answer_value() when no number can be extracted.
# It is negative, so it can never equal a value parsed from an unsigned
# digit run.
INVALID = -9999999


25
def download_and_cache_file(url: str, filename: str | None = None) -> str:
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    """Download and cache a file from a URL."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)

    return filename


def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Fetch (or reuse cached copies of) the GSM8K train and test splits.

    Returns:
        A ``(train, test)`` pair. Each is the list of JSON records read from
        the corresponding ``.jsonl`` file.
    """
    split_urls = (
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl",
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl",
    )
    # Download (or locate) both files first, then parse them in order.
    local_paths = [download_and_cache_file(split_url) for split_url in split_urls]
    train_path, test_path = local_paths
    return list(read_jsonl(train_path)), list(read_jsonl(test_path))


def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Yield one parsed JSON value per record line of a JSONL file.

    Lines whose first non-whitespace character is ``#`` are treated as
    comments and skipped. Blank lines are also skipped (e.g. a trailing empty
    line), since ``json.loads`` would otherwise fail on them.

    Args:
        filename: Path to the ``.jsonl`` file, decoded as UTF-8.

    Yields:
        The decoded JSON value of each record line.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            yield json.loads(stripped)


def get_answer_value(answer_str: str) -> int:
    """Extract the last integer mentioned in a response or reference answer.

    Thousands separators (``,``) are removed first, so ``1,234`` parses as
    ``1234``. Only unsigned digit runs are recognized; signs and decimal
    points are not interpreted (``-3.5`` yields ``5``).

    Args:
        answer_str: Model completion or GSM8K reference answer text.

    Returns:
        The value of the last digit run, or ``INVALID`` if there is none.
    """
    numbers = re.findall(r"\d+", answer_str.replace(",", ""))
    if not numbers:
        return INVALID
    # Use int() rather than ast.literal_eval(): literal_eval rejects digit
    # runs with leading zeros (e.g. "007") as a SyntaxError, which wrongly
    # turned such answers into INVALID.
    try:
        return int(numbers[-1])
    except ValueError:  # e.g. a digit run longer than sys.int_max_str_digits
        return INVALID


78
79
80
81
82
async def call_vllm_api(
    session: aiohttp.ClientSession,
    prompt: str,
    temperature: float,
    max_tokens: int,
    stop: list[str] | None = None,
    url: str | None = None,
    seed: int | None = None,
) -> tuple[str, int]:
    """POST one prompt to ``{url}/v1/completions`` and return the reply.

    Any failure (error status, connection problem, unexpected response body)
    is printed and reported as an empty completion, so one bad request does
    not abort the whole evaluation.

    Returns:
        ``(text, completion_tokens)`` for the first choice, or ``("", 0)`` on
        error. ``completion_tokens`` falls back to 0 when the response has no
        usage information.
    """
    payload = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
    }
    # Only send a seed when one was requested.
    if seed is not None:
        payload["seed"] = seed

    try:
        async with session.post(f"{url}/v1/completions", json=payload) as resp:
            resp.raise_for_status()
            body = await resp.json()
            first_choice = body["choices"][0]
            usage = body.get("usage", {})
            return first_choice["text"], usage.get("completion_tokens", 0)
    except Exception as e:
        print(f"Error calling vLLM API: {e}")
        return "", 0
111
112


113
def _build_gsm8k_prompts(
114
115
    num_questions: int = 1319,
    num_shots: int = 5,
116
117
118
119
) -> tuple[list[str], list[int]]:
    """Build few-shot GSM8K completion prompts and ground-truth labels."""
    if num_questions == 0:
        return [], []
120
121
122
123
124
    train_data, test_data = load_gsm8k_data()
    num_questions = min(num_questions, len(test_data))

    few_shot_examples = ""
    for i in range(num_shots):
125
126
127
128
        few_shot_examples += (
            f"Question: {train_data[i]['question']}\n"
            f"Answer: {train_data[i]['answer']}\n\n"
        )
129

130
    prompts = []
131
132
    labels = []
    for i in range(num_questions):
133
134
135
        prompts.append(
            few_shot_examples + f"Question: {test_data[i]['question']}\nAnswer:"
        )
136
137
138
        labels.append(get_answer_value(test_data[i]["answer"]))

    assert all(label != INVALID for label in labels), "Some labels are invalid"
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
    return prompts, labels


def _score_gsm8k(
    states: list[str],
    output_tokens: list[int],
    labels: list[int],
    num_shots: int,
    max_tokens: int,
    latency: float,
) -> dict[str, float | int]:
    """Score GSM8K responses and return a results dict."""
    num_questions = len(labels)
    preds = [get_answer_value(state) for state in states]
    accuracy = np.mean(np.array(preds) == np.array(labels))
    invalid_rate = np.mean(np.array(preds) == INVALID)
    total_output_tokens = sum(output_tokens)
    tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0

    return {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": num_questions / latency if latency > 0 else 0.0,
        "total_output_tokens": total_output_tokens,
        "tokens_per_second": tokens_per_second,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }


def evaluate_gsm8k(
    num_questions: int = 1319,
    num_shots: int = 5,
    max_tokens: int = 256,
    host: str = "http://127.0.0.1",
    port: int = 8000,
    temperature: float = 0.0,
    seed: int | None = 42,
) -> dict[str, float | int]:
    """Run GSM8K against a ``vllm serve`` endpoint and score the answers.

    One request per question is scheduled on a single aiohttp session
    (``ClientTimeout(total=600)``), and all are awaited together via
    ``tqdm.gather``.

    Returns:
        The dict produced by ``_score_gsm8k`` (accuracy, invalid_rate,
        latency, throughput, ...).
    """
    prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)
    num_questions = len(prompts)
    base_url = f"{host}:{port}"
    stop_sequences = ["Question", "Assistant:", "<|separator|>"]

    async def _run_all() -> tuple[list[str], list[int]]:
        # Results are written by index so their order matches `labels`.
        texts: list[str] = [""] * num_questions
        token_counts: list[int] = [0] * num_questions

        async def _answer(
            session: aiohttp.ClientSession, idx: int
        ) -> tuple[str, int]:
            text, n_tokens = await call_vllm_api(
                session=session,
                prompt=prompts[idx],
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop_sequences,
                url=base_url,
                seed=seed,
            )
            texts[idx] = text
            token_counts[idx] = n_tokens
            return text, n_tokens

        client_timeout = aiohttp.ClientTimeout(total=600)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            await tqdm.gather(
                *(_answer(session, idx) for idx in range(num_questions)),
                desc="Evaluating",
            )
        return texts, token_counts

    print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")

    start = time.perf_counter()
    states, output_tokens = asyncio.run(_run_all())
    elapsed = time.perf_counter() - start

    return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, elapsed)
223
224


225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def evaluate_gsm8k_offline(
    llm,
    num_questions: int = 1319,
    num_shots: int = 5,
    max_tokens: int = 256,
    temperature: float = 0.0,
    seed: int | None = None,
) -> dict[str, float | int]:
    """Evaluate GSM8K accuracy using an offline vllm.LLM object.

    Same prompts and scoring as evaluate_gsm8k(), but runs generation
    directly via llm.generate() instead of calling a server over HTTP.

    Args:
        llm: An object exposing ``generate(prompts, sampling_params)``.
            Presumably a ``vllm.LLM``; the code only reads
            ``outputs[0].text`` and ``outputs[0].token_ids`` from each result.
        num_questions: Number of test questions to evaluate.
        num_shots: Number of few-shot examples in each prompt.
        max_tokens: Maximum completion length per question.
        temperature: Sampling temperature.
        seed: Optional sampling seed forwarded to ``SamplingParams``. The
            default ``None`` keeps the previous (unseeded) behavior.

    Returns:
        The dict produced by ``_score_gsm8k``.
    """
    from vllm import SamplingParams

    prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        stop=["Question", "Assistant:", "<|separator|>"],
        seed=seed,
    )

    print(
        f"Running offline GSM8K evaluation: {len(prompts)} questions, {num_shots}-shot"
    )

    if not prompts:
        # Nothing to generate; skip the engine call entirely.
        return _score_gsm8k([], [], labels, num_shots, max_tokens, 0.0)

    tic = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params)
    latency = time.perf_counter() - tic

    # Only the first completion of each request is scored.
    states = [o.outputs[0].text for o in outputs]
    output_tokens = [len(o.outputs[0].token_ids) for o in outputs]

    return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency)
259
260
261


def main() -> None:
    """Command-line entry point.

    Parses flags, runs the online GSM8K evaluation, prints a summary and
    optionally writes the full result dict to a JSON file.
    """
    parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve")
    # Flags are registered in this order, which is also the --help order.
    cli_options = [
        ("--num-shots", dict(type=int, default=5, help="Number of few-shot examples")),
        (
            "--num-questions",
            dict(type=int, default=1319, help="Number of questions to evaluate"),
        ),
        ("--max-tokens", dict(type=int, default=256, help="Max tokens for generation")),
        ("--host", dict(type=str, default="http://127.0.0.1", help="Host URL")),
        ("--port", dict(type=int, default=8000, help="Port number")),
        (
            "--temperature",
            dict(type=float, default=0.0, help="Temperature for generation"),
        ),
        ("--seed", dict(type=int, default=42, help="Random seed for reproducibility")),
        ("--save-results", dict(type=str, help="Save results to JSON file")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    result = evaluate_gsm8k(
        num_questions=args.num_questions,
        num_shots=args.num_shots,
        max_tokens=args.max_tokens,
        host=args.host,
        port=args.port,
        temperature=args.temperature,
        seed=args.seed,
    )

    summary_lines = [
        "\nResults:",
        f"Accuracy: {result['accuracy']:.3f}",
        f"Invalid responses: {result['invalid_rate']:.3f}",
        f"Total latency: {result['latency']:.3f} s",
        f"Questions per second: {result['questions_per_second']:.3f}",
        f"Total output tokens: {result['total_output_tokens']}",
        f"Output tokens per second: {result['tokens_per_second']:.3f}",
    ]
    for summary_line in summary_lines:
        print(summary_line)

    if not args.save_results:
        return
    with open(args.save_results, "w") as f:
        json.dump(result, f, indent=2)
    print(f"Results saved to {args.save_results}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()