""" Benchmark the latency of running a single batch with a server. This script launches a server and uses the HTTP interface. It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). Usage: python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage """ import argparse import dataclasses import itertools import json import multiprocessing import os import time from typing import List, Tuple import requests from sglang.bench_serving import get_tokenizer, sample_random_requests from sglang.profiler import run_profile from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_blackwell, kill_process_tree from sglang.test.test_utils import is_in_ci, write_github_step_summary @dataclasses.dataclass class BenchArgs: run_name: str = "default" batch_size: Tuple[int] = (1,) input_len: Tuple[int] = (1024,) output_len: Tuple[int] = (16,) temperature: float = 0.0 return_logprob: bool = False client_stream_interval: int = 1 input_len_step_percentage: float = 0.0 result_filename: str = "result.jsonl" base_url: str = "" skip_warmup: bool = False show_report: bool = False profile: bool = False profile_steps: int = 3 profile_by_stage: bool = False dataset_path: str = "" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) parser.add_argument( "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size ) parser.add_argument( "--input-len", type=int, nargs="+", default=BenchArgs.input_len ) parser.add_argument( "--output-len", type=int, nargs="+", default=BenchArgs.output_len ) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument("--return-logprob", action="store_true") parser.add_argument( "--client-stream-interval", type=int, default=BenchArgs.client_stream_interval, ) parser.add_argument( "--input-len-step-percentage", type=float, default=BenchArgs.input_len_step_percentage, ) parser.add_argument( "--result-filename", type=str, default=BenchArgs.result_filename ) parser.add_argument("--base-url", type=str, default=BenchArgs.base_url) parser.add_argument("--skip-warmup", action="store_true") parser.add_argument("--show-report", action="store_true") parser.add_argument("--profile", action="store_true") parser.add_argument( "--profile-steps", type=int, default=BenchArgs.profile_steps ) parser.add_argument("--profile-by-stage", action="store_true") parser.add_argument( "--dataset-path", type=str, default=BenchArgs.dataset_path, help="Path to the dataset.", ) @classmethod def from_cli_args(cls, args: argparse.Namespace): # use the default value's type to cast the args into correct types. 
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


def launch_server_internal(server_args):
    try:
        launch_server(server_args)
    except Exception as e:
        raise e
    finally:
        kill_process_tree(os.getpid(), include_parent=False)


def launch_server_process(server_args: ServerArgs):
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()

    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            pass
        time.sleep(10)

    raise TimeoutError("Server failed to start within the timeout period.")


def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
    dataset_path: str = "",
):
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=dataset_path,
        random_sample=True,
        return_text=False,
    )

    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None

    profile_link = None
    if profile:
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )

    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )

    # The TTFT of the last request in the batch
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")
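            # Each request must either still be generating or have stopped because it
            # reached max_new_tokens ("length"); any other finish reason is unexpected.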
{data}.") assert ( data["meta_info"]["finish_reason"] is None or data["meta_info"]["finish_reason"]["type"] == "length" ) if data["meta_info"]["completion_tokens"] == 1: ttft = time.perf_counter() - tic latency = time.perf_counter() - tic input_throughput = batch_size * input_len / ttft output_throughput = batch_size * output_len / (latency - ttft) overall_throughput = batch_size * (input_len + output_len) / latency server_info = requests.get(url + "/get_server_info").json() acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None) last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"] print(f"batch size: {batch_size}") print(f"input_len: {input_len}") print(f"output_len: {output_len}") print(f"latency: {latency:.2f} s") print(f"ttft: {ttft:.2f} s") print(f"last generation throughput: {last_gen_throughput:.2f} tok/s") print(f"input throughput: {input_throughput:.2f} tok/s") if output_len != 1: print(f"output throughput: {output_throughput:.2f} tok/s") if result_filename: with open(result_filename, "a") as fout: res = { "run_name": run_name, "batch_size": batch_size, "input_len": input_len, "output_len": output_len, "latency": round(latency, 4), "output_throughput": round(output_throughput, 2), "overall_throughput": round(overall_throughput, 2), "last_gen_throughput": round(last_gen_throughput, 2), } fout.write(json.dumps(res) + "\n") return ( batch_size, latency, ttft, input_throughput, output_throughput, overall_throughput, last_gen_throughput, acc_length, profile_link if profile else None, ) def get_report_summary( result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs ): import tabulate summary = ( f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n" ) headers = [ "batch size", "latency (s)", "input throughput (tok/s)", "output throughput (tok/s)", "acc length", "ITL (ms)", "input cost ($/1M)", "output cost ($/1M)", ] if bench_args.profile: headers.append("profile") rows = [] for ( batch_size, latency, ttft, input_throughput, output_throughput, _, _, acc_length, trace_link, ) in result: if is_blackwell(): hourly_cost_per_gpu = 4 # $4/hour for one B200 else: hourly_cost_per_gpu = 2 # $2/hour for one H100 hourly_cost = hourly_cost_per_gpu * server_args.tp_size input_util = 0.7 accept_length = round(acc_length, 2) if acc_length is not None else "n/a" itl = 1 / (output_throughput / batch_size) * 1000 input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost output_cost = 1e6 / output_throughput / 3600 * hourly_cost row = [ batch_size, latency, input_throughput, output_throughput, accept_length, itl, input_cost, output_cost, ] if trace_link: row.append(f"[Profile]({trace_link})") rows.append(row) summary += tabulate.tabulate( rows, headers=headers, tablefmt="github", floatfmt=".2f" ) return summary def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): if bench_args.base_url: proc, base_url = None, bench_args.base_url else: proc, base_url = launch_server_process(server_args) server_info = requests.get(base_url + "/get_server_info").json() if "tokenizer_path" in server_info: tokenizer_path = server_info["tokenizer_path"] elif "prefill" in server_info: tokenizer_path = server_info["prefill"][0]["tokenizer_path"] tokenizer = get_tokenizer(tokenizer_path) # warmup if not bench_args.skip_warmup: print("=" * 8 + " Warmup Begin " + "=" * 8) run_one_case( base_url, batch_size=16, input_len=1024, output_len=16, temperature=bench_args.temperature, 
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
            dataset_path=bench_args.dataset_path,
        )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

    # benchmark
    result = []
    bench_result = []
    try:
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                    dataset_path=bench_args.dataset_path,
                )
            )

        if bench_args.profile:
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_steps=bench_args.profile_steps,
                                profile_by_stage=bench_args.profile_by_stage,
                                dataset_path=bench_args.dataset_path,
                            )[-1],
                        )
                    )
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        if proc:
            kill_process_tree(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")

    if not bench_args.show_report:
        return

    summary = get_report_summary(result, server_args, bench_args)
    print(summary)

    if is_in_ci():
        write_github_step_summary(summary)


def main():
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    run_benchmark(server_args, bench_args)


if __name__ == "__main__":
    main()