benchmark_batch.py 7.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import concurrent.futures
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor
from statistics import mean

import requests
from tqdm import tqdm
from transformers import AutoTokenizer

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

###############################################################################
# CONFIG
###############################################################################
# Base URL of the running server; "/generate" is appended when sending requests.
ENDPOINT_URL = "http://127.0.0.1:30000"
# Path handed to AutoTokenizer.from_pretrained() to build random prompts.
TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"

# Benchmark configurations
# Requests are sent one after another; each carries BATCH_SIZE prompts, so
# NUM_REQUESTS * BATCH_SIZE prompts are generated in total.
NUM_REQUESTS = 10  # Total number of requests (each with BATCH_SIZE prompts)
# Number of random token ids sampled per prompt before they are decoded to text.
NUM_TOKENS = 32000  # Tokens per prompt
BATCH_SIZE = 8  # Number of prompts per request
# Sent to the server as sampling_params["max_new_tokens"].
GEN_TOKENS = 0  # Tokens to generate per prompt


###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
# Tokenizers already loaded in the current process, keyed by directory.
# Prompt generation runs in pool worker processes, so each worker fills its
# own copy of this cache the first time it is asked for a tokenizer.
_TOKENIZER_CACHE = {}


def _get_tokenizer(tokenizer_dir):
    """Return the tokenizer for *tokenizer_dir*, loading it at most once per process."""
    tokenizer = _TOKENIZER_CACHE.get(tokenizer_dir)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        _TOKENIZER_CACHE[tokenizer_dir] = tokenizer
    return tokenizer


def generate_random_prompt(index, tokenizer_dir, num_tokens):
    """Generate a single random prompt built from ``num_tokens`` random token ids.

    Args:
        index: Position of the prompt; embedded in the returned text as
            ``"Prompt {index}: "``.
        tokenizer_dir: Directory (or model id) passed to
            ``AutoTokenizer.from_pretrained``.
        num_tokens: Number of token ids to sample uniformly from
            ``[0, vocab_size - 1]``.

    Returns:
        The prompt string: the index prefix followed by the decoded random ids.

    NOTE(review): the decoded text plus the prefix is not guaranteed to
    re-tokenize to exactly ``num_tokens`` tokens on the server side.
    """
    # Loading a tokenizer is expensive; reuse it across the many prompts that
    # a single worker process generates instead of reloading it every call.
    tokenizer = _get_tokenizer(tokenizer_dir)
    vocab_size = tokenizer.vocab_size

    random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_tokens)]
    random_text = tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)
    return f"Prompt {index}: {random_text}"


def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
    """Generate prompts for all requests in parallel and group them into batches.

    Args:
        num_requests: Number of batches (requests) to build.
        batch_size: Number of prompts in each batch.
        num_tokens: Number of random tokens per prompt.
        tokenizer_dir: Tokenizer location forwarded to ``generate_random_prompt``.

    Returns:
        A list of ``num_requests`` lists, each holding ``batch_size`` prompt
        strings. Prompt ``i`` (by generation index) lands in batch
        ``i // batch_size``.
    """
    total_prompts = num_requests * batch_size
    all_prompts = [None] * total_prompts

    # ProcessPoolExecutor raises ValueError for max_workers <= 0, so skip the
    # pool entirely when there is nothing to generate.
    if total_prompts > 0:
        max_workers = min(os.cpu_count() or 1, total_prompts)

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Map each future to its prompt index so completed results can be
            # placed in O(1) instead of scanning a list for every completion.
            future_to_index = {
                executor.submit(
                    generate_random_prompt, i, tokenizer_dir, num_tokens
                ): i
                for i in range(total_prompts)
            }
            for future in tqdm(
                concurrent.futures.as_completed(future_to_index),
                total=total_prompts,
                desc="Generating prompts",
            ):
                all_prompts[future_to_index[future]] = future.result()

    batched_prompts = [
        all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
    ]

    print(
        f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
    )
    return batched_prompts


###############################################################################
# HTTP CALLS
###############################################################################
def send_batch_request(endpoint, prompts, gen_tokens, request_id):
    """Send a batch of prompts to the /generate endpoint synchronously.

    Args:
        endpoint: Object exposing ``base_url``; ``"/generate"`` is appended to it.
        prompts: List of prompt strings sent together as one request.
        gen_tokens: Value sent as ``max_new_tokens`` in the sampling params.
        request_id: Identifier echoed back in the result and in log messages.

    Returns:
        A tuple ``(request_id, elapsed_ms, avg_ms_per_prompt, success,
        num_prompts)``. On any failure the error is printed and the tuple is
        ``(request_id, 0, 0, False, num_prompts)``.
    """
    sampling_params = {
        "max_new_tokens": gen_tokens,
        "temperature": 0.7,
        "stop": "\n",
    }
    data = {"text": prompts, "sampling_params": sampling_params}

    # perf_counter is monotonic and high resolution, so the measured interval
    # cannot be distorted by wall-clock adjustments.
    start_time = time.perf_counter()
    try:
        response = requests.post(
            endpoint.base_url + "/generate", json=data, timeout=3600
        )
        if response.status_code != 200:
            # Error bodies are not guaranteed to be JSON; fall back to the
            # status code and raw text so the real failure is still reported.
            try:
                error = response.json()
            except ValueError:
                error = f"HTTP {response.status_code}: {response.text}"
            raise RuntimeError(f"Request {request_id} failed: {error}")
        # Parse the body inside the timed region; an unparsable body raises
        # and is reported as a failed request below.
        response.json()
        elapsed_time = (time.perf_counter() - start_time) * 1000  # Convert to ms
        avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
        return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
    except Exception as e:
        # Best effort: a failed request must not abort the whole benchmark.
        print(f"[Request] Error for request {request_id}: {e}")
        return request_id, 0, 0, False, len(prompts)


def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
    """Run the benchmark sequentially, one request per batch.

    Args:
        endpoint: Endpoint object forwarded to ``send_batch_request``.
        batched_prompts: List of prompt batches; each batch becomes one request.
        batch_size: Expected number of prompts in every batch.
        gen_tokens: Tokens to generate per prompt, forwarded to each request.

    Returns:
        A tuple ``(results, total_latency)`` where ``results`` holds one
        ``send_batch_request`` result per batch, in order, and
        ``total_latency`` is the time for the whole loop in milliseconds.

    Raises:
        AssertionError: If a batch does not contain ``batch_size`` prompts.
    """
    results = []
    num_requests = len(batched_prompts)

    # Use the monotonic clock for the interval; wall-clock time can jump.
    benchmark_start_time = time.perf_counter()

    for request_id, batch_prompts in enumerate(batched_prompts, start=1):
        assert (
            len(batch_prompts) == batch_size
        ), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"

        # The logged value is an epoch timestamp in ms, so this one
        # intentionally stays on the wall clock.
        print(
            f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
        )
        result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
        results.append(result)

    # Calculate total latency
    total_latency = (time.perf_counter() - benchmark_start_time) * 1000  # Convert to ms

    return results, total_latency


###############################################################################
# RESULTS
###############################################################################
def process_results(results, total_latency, num_requests):
    """Process and display benchmark results.

    Args:
        results: Iterable of ``(request_id, elapsed_ms, avg_ms_per_prompt,
            success, batch_size)`` tuples, one per request.
        total_latency: Time for the whole benchmark loop, in milliseconds.
        num_requests: Currently unused; kept so existing callers keep working.

    Returns:
        None. The summary is printed to stdout.
    """
    total_time = 0
    successful_requests = 0
    failed_requests = 0
    request_latencies = []
    per_prompt_latencies = []
    prompts_sent = 0
    prompts_completed = 0

    for _request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
        # Failed requests still carry their prompt count, and those prompts
        # were sent, so they belong in the "sent" total.
        prompts_sent += batch_size
        if not success:
            failed_requests += 1
            continue
        successful_requests += 1
        prompts_completed += batch_size
        request_latencies.append(elapsed_time)
        per_prompt_latencies.append(avg_per_prompt)
        total_time += elapsed_time / 1000  # Convert to seconds

    avg_request_latency = mean(request_latencies) if request_latencies else 0
    avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
    # Throughput covers only prompts from successful requests, divided by the
    # summed duration of those requests.
    throughput = prompts_completed / total_time if total_time > 0 else 0

    print("\nBenchmark Summary:")
    print(f"  Total requests sent:         {len(results)}")
    print(f"  Total prompts sent:          {prompts_sent}")
    print(f"  Successful requests:         {successful_requests}")
    print(f"  Failed requests:             {failed_requests}")
    print(f"  Total latency (all requests): {total_latency:.2f} ms")
    print(f"  Avg per request latency:     {avg_request_latency:.2f} ms")
    print(f"  Avg per prompt latency:      {avg_per_prompt_latency:.2f} ms")
    print(f"  Throughput:                  {throughput:.2f} prompts/second\n")


###############################################################################
# MAIN
###############################################################################
def main():
    """Build the prompts, run the sequential benchmark, and print the summary."""
    # The endpoint object is created first; it supplies the base URL for requests.
    runtime = RuntimeEndpoint(ENDPOINT_URL)

    # All prompts are generated up front so generation time is not measured.
    prompt_batches = prepare_all_prompts(
        num_requests=NUM_REQUESTS,
        batch_size=BATCH_SIZE,
        num_tokens=NUM_TOKENS,
        tokenizer_dir=TOKENIZER_DIR,
    )

    # The cache flush before the benchmark is currently disabled:
    # runtime.flush_cache()

    banner = f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
    print(banner)

    request_results, elapsed_ms = run_benchmark(
        runtime, prompt_batches, BATCH_SIZE, GEN_TOKENS
    )

    # Summarise and print the per-request measurements.
    process_results(request_results, elapsed_ms, NUM_REQUESTS)


# Script entry point: seed the RNG of this process, then run the benchmark.
if __name__ == "__main__":
    # NOTE(review): this seeds only the launching process. Prompts are
    # generated inside ProcessPoolExecutor workers (see prepare_all_prompts),
    # so confirm whether the seed actually makes the prompts reproducible.
    random.seed(0)
    main()