gsm8k_eval.py 8.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""

import argparse
import ast
import asyncio
import json
import os
import time
from collections.abc import Generator

import aiohttp
import numpy as np
import regex as re
import requests
from tqdm.asyncio import tqdm

INVALID = -9999999


25
def download_and_cache_file(url: str, filename: str | None = None) -> str:
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    """Download and cache a file from a URL."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)

    return filename


def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Fetch (or reuse cached copies of) the GSM8K splits and parse them.

    Returns:
        A ``(train, test)`` pair, each a list of the records parsed from the
        corresponding JSONL file.
    """
    split_urls = (
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl",
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl",
    )

    # Resolve both local paths first, then parse them in the same order.
    local_paths = [download_and_cache_file(split_url) for split_url in split_urls]
    train_records, test_records = (list(read_jsonl(path)) for path in local_paths)

    return train_records, test_records


def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Yield one parsed JSON value per data line of a JSONL file.

    Blank lines and lines whose first non-whitespace character is ``#`` are
    skipped. The file is read as UTF-8.

    Args:
        filename: Path of the JSONL file to read.

    Yields:
        The decoded JSON value of each remaining line, in file order.

    Raises:
        json.JSONDecodeError: If a data line is not valid JSON.
    """
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            record = line.strip()
            # Tolerate trailing/blank lines and comment lines instead of
            # handing them to the JSON parser.
            if not record or record.startswith("#"):
                continue
            yield json.loads(record)


def get_answer_value(answer_str: str) -> int:
    """Return the last run of digits in the text as an integer.

    Commas are removed before searching so that "1,234" is read as 1234.

    Args:
        answer_str: Text to scan (a reference answer or a model response).

    Returns:
        The integer value of the last digit run, or ``INVALID`` when the text
        contains no digits or the digits cannot be converted.
    """
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID
    try:
        # int() accepts leading zeros ("05" -> 5); parsing the digits as a
        # Python literal would reject them with a SyntaxError.
        return int(digit_runs[-1])
    except ValueError:
        return INVALID


78
79
80
81
82
async def call_vllm_api(
    session: aiohttp.ClientSession,
    prompt: str,
    temperature: float,
    max_tokens: int,
    stop: list[str] | None = None,
    url: str | None = None,
    seed: int | None = None,
) -> tuple[str, int]:
    """Call vLLM's OpenAI-compatible completions endpoint.

    Args:
        session: HTTP session used to send the POST request.
        prompt: Prompt text sent as the ``prompt`` field.
        temperature: Sampling temperature sent with the request.
        max_tokens: Generation limit sent as ``max_tokens``.
        stop: Optional stop strings, sent as the ``stop`` field.
        url: Base URL of the server; ``/v1/completions`` is appended to it.
        seed: Optional seed; only included in the request when not ``None``.

    Returns:
        Tuple of (response_text, completion_tokens). Any failure is printed
        and reported as ``("", 0)`` instead of being raised.
    """
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
    }
    if seed is not None:
        data["seed"] = seed

    try:
        async with session.post(f"{url}/v1/completions", json=data) as response:
            response.raise_for_status()
            result = await response.json()

        text = result["choices"][0]["text"]
        # "usage" may be absent or null; neither case should discard a
        # completion that was otherwise returned successfully.
        usage = result.get("usage") or {}
        completion_tokens = usage.get("completion_tokens") or 0
        return text, completion_tokens
    except Exception as e:
        # Deliberately best-effort: one failed request must not abort the
        # whole evaluation run.
        print(f"Error calling vLLM API: {e}")
        return "", 0
111
112


113
114
115
116
117
118
119
def evaluate_gsm8k(
    num_questions: int = 1319,
    num_shots: int = 5,
    max_tokens: int = 256,
    host: str = "http://127.0.0.1",
    port: int = 8000,
    temperature: float = 0.0,
    seed: int | None = 42,
) -> dict[str, float | int]:
    """
    Evaluate GSM8K accuracy using vLLM serve endpoint.

    Args:
        num_questions: Number of test questions to evaluate; clamped to the
            size of the test split.
        num_shots: Number of few-shot examples taken from the start of the
            train split; clamped to the size of the train split.
        max_tokens: Generation limit passed to every request.
        host: Server host including scheme.
        port: Server port; combined with ``host`` as ``{host}:{port}``.
        temperature: Sampling temperature passed to every request.
        seed: Seed passed to every request (omitted when ``None``).

    Returns dict with accuracy, invalid_rate, latency, etc.
    """
    base_url = f"{host}:{port}"

    # Load GSM8K train and test data
    train_data, test_data = load_gsm8k_data()

    # Clamp both counts to what the splits can provide so an oversized (or
    # negative) request cannot index past the end of the data.
    num_questions = max(0, min(num_questions, len(test_data)))
    num_shots = max(0, min(num_shots, len(train_data)))

    # Build few-shot examples from train split (like lm-eval does)
    few_shot_examples = "".join(
        f"Question: {example['question']}\nAnswer: {example['answer']}\n\n"
        for example in train_data[:num_shots]
    )

    # Prepare test questions and labels from test split
    test_subset = test_data[:num_questions]
    questions = [f"Question: {item['question']}\nAnswer:" for item in test_subset]
    labels = [get_answer_value(item["answer"]) for item in test_subset]

    assert all(label != INVALID for label in labels), "Some labels are invalid"

    # Run evaluation
    async def run_async_evaluation():
        states: list[str] = [""] * num_questions
        output_tokens: list[int] = [0] * num_questions

        async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
            prompt = few_shot_examples + questions[i]
            answer, tokens = await call_vllm_api(
                session=session,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=["Question", "Assistant:", "<|separator|>"],
                url=base_url,
                seed=seed,
            )
            # Results are stored by question index, so the order in which
            # requests complete does not matter.
            states[i] = answer
            output_tokens[i] = tokens
            return answer, tokens

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=600)
        ) as session:
            tasks = [get_answer(session, i) for i in range(num_questions)]
            await tqdm.gather(*tasks, desc="Evaluating")

        return states, output_tokens

    print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")

    tic = time.perf_counter()
    states, output_tokens = asyncio.run(run_async_evaluation())
    latency = time.perf_counter() - tic

    # Compute metrics
    preds = np.array([get_answer_value(state) for state in states])
    if num_questions > 0:
        # float() yields plain Python floats, matching the return annotation.
        accuracy = float(np.mean(preds == np.array(labels)))
        invalid_rate = float(np.mean(preds == INVALID))
    else:
        # The mean of an empty array is NaN, which is not valid JSON.
        accuracy = 0.0
        invalid_rate = 0.0
    total_output_tokens = sum(output_tokens)
    # Both throughput figures share the same zero-latency guard.
    questions_per_second = num_questions / latency if latency > 0 else 0.0
    tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0

    result = {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": questions_per_second,
        "total_output_tokens": total_output_tokens,
        "tokens_per_second": tokens_per_second,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }

    return result


def main() -> None:
    """Command-line entry point: parse options, run the evaluation, report.

    The metrics are printed to the terminal and, when ``--save-results`` is
    given, also written to that path as JSON.
    """
    cli = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve")

    # Declared as data so each option sits on one line; the order here is the
    # order shown in --help.
    option_specs = (
        ("--num-shots", {"type": int, "default": 5, "help": "Number of few-shot examples"}),
        ("--num-questions", {"type": int, "default": 1319, "help": "Number of questions to evaluate"}),
        ("--max-tokens", {"type": int, "default": 256, "help": "Max tokens for generation"}),
        ("--host", {"type": str, "default": "http://127.0.0.1", "help": "Host URL"}),
        ("--port", {"type": int, "default": 8000, "help": "Port number"}),
        ("--temperature", {"type": float, "default": 0.0, "help": "Temperature for generation"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed for reproducibility"}),
        ("--save-results", {"type": str, "help": "Save results to JSON file"}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    metrics = evaluate_gsm8k(
        num_questions=opts.num_questions,
        num_shots=opts.num_shots,
        max_tokens=opts.max_tokens,
        host=opts.host,
        port=opts.port,
        temperature=opts.temperature,
        seed=opts.seed,
    )

    # Terminal report
    report = [
        "\nResults:",
        f"Accuracy: {metrics['accuracy']:.3f}",
        f"Invalid responses: {metrics['invalid_rate']:.3f}",
        f"Total latency: {metrics['latency']:.3f} s",
        f"Questions per second: {metrics['questions_per_second']:.3f}",
        f"Total output tokens: {metrics['total_output_tokens']}",
        f"Output tokens per second: {metrics['tokens_per_second']:.3f}",
    ]
    for entry in report:
        print(entry)

    # Saving to a file is optional
    destination = opts.save_results
    if not destination:
        return
    with open(destination, "w") as out:
        json.dump(metrics, out, indent=2)
    print(f"Results saved to {destination}")


if __name__ == "__main__":
    main()