Unverified commit 3abc3036 authored by Simo Lin, committed by GitHub

[ci] add router benchmark script and CI (#7498)

parent afeed465
name: PR Benchmark (Rust Router)
on:
push:
branches: [ main ]
paths:
- "sgl-router/**"
pull_request:
branches: [ main ]
paths:
- "sgl-router/**"
workflow_dispatch:
concurrency:
group: pr-benchmark-rust-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
pull-requests: write
issues: write
jobs:
benchmark-router:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
# Fetch enough history for baseline comparison
fetch-depth: 100
- name: Install dependencies
run: |
bash scripts/ci_install_rust.sh
- name: Cache Rust dependencies
uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
sgl-router/target/
key: ${{ runner.os }}-cargo-${{ hashFiles('sgl-router/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build router in release mode
run: |
source "$HOME/.cargo/env"
cd sgl-router/
cargo build --release
- name: Run quick benchmarks
timeout-minutes: 15
run: |
source "$HOME/.cargo/env"
cd sgl-router/
# Run quick benchmarks for PR validation using the Python script
python3 scripts/run_benchmarks.py --quick --validate-thresholds --save-results
- name: Upload benchmark results
if: always()
uses: actions/upload-artifact@v4
with:
name: benchmark-results-${{ github.sha }}
path: |
sgl-router/target/criterion/
retention-days: 30
- name: Post benchmark results as PR comment
if: github.event_name == 'pull_request'
run: |
cd sgl-router/
# Use the Python script to post the benchmark comment
python3 scripts/post_benchmark_comment.py \
--pr-number ${{ github.event.number }} \
--commit-sha ${{ github.sha }} \
--results-file benchmark_results.env
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
benchmark-integration-test:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
run: |
bash scripts/ci_install_rust.sh
- name: Cache Rust dependencies
uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
sgl-router/target/
key: ${{ runner.os }}-cargo-${{ hashFiles('sgl-router/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Run benchmark integration tests
timeout-minutes: 10
run: |
source "$HOME/.cargo/env"
cd sgl-router/
# Run integration tests to ensure benchmark code compiles and works
cargo test --test benchmark_integration
- name: Verify benchmark compilation
run: |
source "$HOME/.cargo/env"
cd sgl-router/
# Ensure all benchmarks compile without running them
cargo check --benches
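To reproduce the PR benchmark gate locally, the "Run quick benchmarks" step above reduces to a single command executed from sgl-router/. A minimal sketch, assuming a Rust toolchain is installed and the repository is checked out; the wrapper itself is hypothetical and not part of this commit:

#!/usr/bin/env python3
# Hypothetical local wrapper mirroring the CI "Run quick benchmarks" step.
# Run from the repository root; requires cargo on PATH.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "scripts/run_benchmarks.py",
     "--quick", "--validate-thresholds", "--save-results"],
    cwd="sgl-router",
)
# run_benchmarks.py exits non-zero when a threshold check fails, so passing
# the return code through reproduces the CI pass/fail behaviour.
sys.exit(result.returncode)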
@@ -40,6 +40,20 @@ jobs:
cd sgl-router/
cargo test
- name: Check benchmark compilation
run: |
source "$HOME/.cargo/env"
cd sgl-router/
cargo check --benches
- name: Quick benchmark sanity check
timeout-minutes: 10
run: |
source "$HOME/.cargo/env"
cd sgl-router/
# Run quick benchmarks via the Python script to make sure they work
python3 scripts/run_benchmarks.py --quick
e2e-python:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 2-gpu-runner
......
@@ -36,6 +36,15 @@ metrics-exporter-prometheus = "0.17.0"
# Added for request tracing
uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
[[bench]]
name = "request_processing"
harness = false
path = "benches/request_processing.rs"
[profile.release]
lto = "thin"
codegen-units = 1
# SGLang Router Makefile
# Provides convenient shortcuts for common development tasks
.PHONY: help bench bench-quick bench-baseline bench-compare test build clean
help: ## Show this help message
@echo "SGLang Router Development Commands"
@echo "=================================="
@echo ""
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}'
@echo ""
build: ## Build the project in release mode
@echo "Building SGLang Router..."
@cargo build --release
test: ## Run all tests
@echo "Running tests..."
@cargo test
bench: ## Run full benchmark suite
@echo "Running full benchmarks..."
@python3 scripts/run_benchmarks.py
bench-quick: ## Run quick benchmarks only
@echo "Running quick benchmarks..."
@python3 scripts/run_benchmarks.py --quick
bench-baseline: ## Save current performance as baseline
@echo "Saving performance baseline..."
@python3 scripts/run_benchmarks.py --save-baseline main
bench-compare: ## Compare with saved baseline
@echo "Comparing with baseline..."
@python3 scripts/run_benchmarks.py --compare-baseline main
bench-ci: ## Run benchmarks suitable for CI (quick mode)
@echo "Running CI benchmarks..."
@python3 scripts/run_benchmarks.py --quick
clean: ## Clean build artifacts
@echo "Cleaning build artifacts..."
@cargo clean
docs: ## Generate and open documentation
@echo "Generating documentation..."
@cargo doc --open
check: ## Run cargo check and clippy
@echo "Running cargo check..."
@cargo check
@echo "Running clippy..."
@cargo clippy
fmt: ## Format code with rustfmt
@echo "Formatting code..."
@cargo fmt
# Development workflow shortcuts
dev-setup: build test ## Set up development environment
@echo "Development environment ready!"
pre-commit: fmt check test bench-quick ## Run pre-commit checks
@echo "Pre-commit checks passed!"
# Benchmark analysis shortcuts
bench-report: ## Open benchmark HTML report
@if [ -f "target/criterion/request_processing/report/index.html" ]; then \
echo "Opening benchmark report..."; \
if command -v xdg-open >/dev/null 2>&1; then \
xdg-open target/criterion/request_processing/report/index.html; \
elif command -v open >/dev/null 2>&1; then \
open target/criterion/request_processing/report/index.html; \
else \
echo "Please open target/criterion/request_processing/report/index.html in your browser"; \
fi \
else \
echo "No benchmark report found. Run 'make bench' first."; \
fi
bench-clean: ## Clean benchmark results
@echo "Cleaning benchmark results..."
@rm -rf target/criterion
# Performance monitoring
perf-monitor: ## Run continuous performance monitoring
@echo "Starting performance monitoring..."
@if command -v watch >/dev/null 2>&1; then \
watch -n 300 'make bench-quick'; \
else \
echo "Warning: 'watch' command not found. Install it or run 'make bench-quick' manually."; \
fi
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_vec};
use std::time::Instant;
use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
// Sample request data for benchmarks
fn create_sample_generate_request() -> GenerateRequest {
GenerateRequest {
text: Some("Write a story about artificial intelligence".to_string()),
input_ids: None,
prompt: None,
parameters: Some(GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
repetition_penalty: Some(1.0),
..Default::default()
}),
sampling_params: Some(SamplingParams {
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
repetition_penalty: Some(1.0),
..Default::default()
}),
stream: false,
return_logprob: false,
}
}
fn create_sample_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![
ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text(
"Explain quantum computing in simple terms".to_string(),
),
name: None,
},
],
max_tokens: Some(150),
max_completion_tokens: Some(150),
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
stream: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
response_format: None,
seed: None,
tools: None,
tool_choice: None,
parallel_tool_calls: Some(true),
function_call: None,
functions: None,
}
}
fn create_sample_completion_request() -> CompletionRequest {
CompletionRequest {
model: "text-davinci-003".to_string(),
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
suffix: None,
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
stream: false,
logprobs: None,
echo: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
logit_bias: None,
user: None,
seed: None,
}
}
fn create_large_chat_completion_request() -> ChatCompletionRequest {
let mut messages = vec![ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant with extensive knowledge.".to_string(),
name: None,
}];
// Add many user/assistant pairs to simulate a long conversation
for i in 0..50 {
messages.push(ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)),
name: None,
});
messages.push(ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
name: None,
tool_calls: None,
function_call: None,
});
}
ChatCompletionRequest {
model: "gpt-4".to_string(),
messages,
max_tokens: Some(1000),
max_completion_tokens: Some(1000),
temperature: Some(0.7),
top_p: Some(0.95),
n: Some(1),
stream: false,
stop: None,
presence_penalty: Some(0.1),
frequency_penalty: Some(0.1),
logit_bias: None,
logprobs: false,
top_logprobs: Some(5),
user: Some("benchmark_user".to_string()),
response_format: None,
seed: Some(42),
tools: None,
tool_choice: None,
parallel_tool_calls: Some(true),
function_call: None,
functions: None,
}
}
// Benchmark JSON serialization
fn bench_json_serialization(c: &mut Criterion) {
let mut group = c.benchmark_group("json_serialization");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request();
group.bench_function("generate_request", |b| {
b.iter(|| {
let json = to_string(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&completion_req)).unwrap();
black_box(json);
});
});
group.bench_function("large_chat_completion_request", |b| {
b.iter(|| {
let json = to_string(black_box(&large_chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_request_to_bytes", |b| {
b.iter(|| {
let bytes = to_vec(black_box(&generate_req)).unwrap();
black_box(bytes);
});
});
group.finish();
}
// Benchmark JSON deserialization
fn bench_json_deserialization(c: &mut Criterion) {
let mut group = c.benchmark_group("json_deserialization");
let generate_json = to_string(&create_sample_generate_request()).unwrap();
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap();
let large_chat_json = to_string(&create_large_chat_completion_request()).unwrap();
group.bench_function("generate_request", |b| {
b.iter(|| {
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
black_box(req);
});
});
group.bench_function("chat_completion_request", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
black_box(req);
});
});
group.bench_function("completion_request", |b| {
b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
black_box(req);
});
});
group.bench_function("large_chat_completion_request", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&large_chat_json)).unwrap();
black_box(req);
});
});
group.finish();
}
// Benchmark request adaptation from OpenAI to PD format
fn bench_request_adaptation(c: &mut Criterion) {
let mut group = c.benchmark_group("request_adaptation");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request();
group.bench_function("generate_to_pd", |b| {
b.iter(|| {
let pd_req = black_box(generate_req.clone()).to_pd_request();
black_box(pd_req);
});
});
group.bench_function("chat_completion_to_pd", |b| {
b.iter(|| {
let pd_req = black_box(chat_req.clone()).to_pd_request();
black_box(pd_req);
});
});
group.bench_function("completion_to_pd", |b| {
b.iter(|| {
let pd_req = black_box(completion_req.clone()).to_pd_request();
black_box(pd_req);
});
});
group.bench_function("large_chat_completion_to_pd", |b| {
b.iter(|| {
let pd_req = black_box(large_chat_req.clone()).to_pd_request();
black_box(pd_req);
});
});
group.finish();
}
// Benchmark regular routing (RouteableRequest methods)
fn bench_regular_routing(c: &mut Criterion) {
let mut group = c.benchmark_group("regular_routing");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
group.bench_function("generate_to_json", |b| {
b.iter(|| {
let json = black_box(&generate_req).to_json().unwrap();
black_box(json);
});
});
group.bench_function("generate_to_bytes", |b| {
b.iter(|| {
let bytes = black_box(&generate_req).to_bytes().unwrap();
black_box(bytes);
});
});
group.bench_function("chat_completion_to_json", |b| {
b.iter(|| {
let json = black_box(&chat_req).to_json().unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_to_bytes", |b| {
b.iter(|| {
let bytes = black_box(&chat_req).to_bytes().unwrap();
black_box(bytes);
});
});
group.bench_function("completion_to_json", |b| {
b.iter(|| {
let json = black_box(&completion_req).to_json().unwrap();
black_box(json);
});
});
group.finish();
}
// Benchmark throughput with different request sizes
fn bench_throughput_by_size(c: &mut Criterion) {
let mut group = c.benchmark_group("throughput_by_size");
// Create requests of different sizes
let small_generate = GenerateRequest {
text: Some("Hi".to_string()),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
let medium_generate = GenerateRequest {
text: Some("Write a medium length story about AI".repeat(10)),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
let large_generate = GenerateRequest {
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
for (name, req) in [
("small", &small_generate),
("medium", &medium_generate),
("large", &large_generate),
] {
let json = to_string(req).unwrap();
let size_bytes = json.len();
group.throughput(Throughput::Bytes(size_bytes as u64));
group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| {
b.iter(|| {
let json = to_string(black_box(req)).unwrap();
black_box(json);
});
});
group.bench_with_input(
BenchmarkId::new("deserialize", name),
&json,
|b, json_str| {
b.iter(|| {
let req: GenerateRequest = black_box(from_str(json_str)).unwrap();
black_box(req);
});
},
);
group.bench_with_input(BenchmarkId::new("adapt_to_pd", name), &req, |b, req| {
b.iter(|| {
let pd_req = (*req).clone().to_pd_request();
black_box(pd_req);
});
});
}
group.finish();
}
// Benchmark full round-trip: deserialize -> adapt -> serialize
fn bench_full_round_trip(c: &mut Criterion) {
let mut group = c.benchmark_group("full_round_trip");
let generate_json = to_string(&create_sample_generate_request()).unwrap();
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap();
group.bench_function("generate_openai_to_pd_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Adapt to PD format
let pd_req = req.to_pd_request();
// Serialize PD request
let pd_json = to_string(&pd_req).unwrap();
black_box(pd_json);
});
});
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
let pd_req = req.to_pd_request();
let pd_json = to_string(&pd_req).unwrap();
black_box(pd_json);
});
});
group.bench_function("completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
let pd_req = req.to_pd_request();
let pd_json = to_string(&pd_req).unwrap();
black_box(pd_json);
});
});
group.bench_function("generate_regular_routing_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Convert to JSON for regular routing
let routing_json = req.to_json().unwrap();
black_box(routing_json);
});
});
group.finish();
}
fn benchmark_summary(c: &mut Criterion) {
let group = c.benchmark_group("benchmark_summary");
println!("\nSGLang Router Performance Benchmark Suite");
println!("=============================================");
// Quick performance overview
let generate_req = create_sample_generate_request();
println!("\nQuick Performance Overview:");
// Measure serialization
let start = Instant::now();
for _ in 0..1000 {
let _ = black_box(to_string(&generate_req).unwrap());
}
let serialize_time = start.elapsed().as_nanos() / 1000;
println!(" * Serialization (avg): {:>8} ns/req", serialize_time);
// Measure deserialization
let json = to_string(&generate_req).unwrap();
let start = Instant::now();
for _ in 0..1000 {
let _: GenerateRequest = black_box(from_str(&json).unwrap());
}
let deserialize_time = start.elapsed().as_nanos() / 1000;
println!(
" * Deserialization (avg): {:>8} ns/req",
deserialize_time
);
// Measure adaptation
let start = Instant::now();
for _ in 0..1000 {
let _ = black_box(generate_req.clone().to_pd_request());
}
let adapt_time = start.elapsed().as_nanos() / 1000;
println!(" * PD Adaptation (avg): {:>8} ns/req", adapt_time);
// Calculate ratios
let total_pipeline = serialize_time + deserialize_time + adapt_time;
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
println!("\nPerformance Insights:");
if deserialize_time > serialize_time * 2 {
println!(" • Deserialization is significantly slower than serialization");
}
if adapt_time < serialize_time / 10 {
println!(
" • PD adaptation overhead is negligible ({:.1}% of serialization)",
(adapt_time as f64 / serialize_time as f64) * 100.0
);
}
if total_pipeline < 10_000 {
println!(" • Total pipeline latency is excellent (< 10μs)");
}
println!("\nRecommendations:");
if serialize_time > deserialize_time {
println!(" • Focus optimization efforts on serialization rather than deserialization");
}
println!(" • PD mode overhead is minimal - safe to use for latency-sensitive workloads");
println!(" • Consider batching small requests to improve overall throughput");
println!("\n{}", "=".repeat(50));
group.finish();
}
criterion_group!(
benches,
benchmark_summary,
bench_json_serialization,
bench_json_deserialization,
bench_request_adaptation,
bench_regular_routing,
bench_throughput_by_size,
bench_full_round_trip
);
criterion_main!(benches);
@@ -16,6 +16,11 @@ classifiers = [
"Programming Language :: Python :: 3",
]
[project.optional-dependencies]
dev = [
"requests>=2.25.0",
]
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
[tool.setuptools.packages]
find = { where = ["py_src"] }
......
#!/usr/bin/env python3
"""
GitHub PR Comment Poster for Benchmark Results
Posts benchmark results as comments on GitHub PRs with update capability.
Replaces JavaScript logic in GitHub Actions for better maintainability.
"""
import argparse
import os
import sys
from pathlib import Path
from typing import Dict, Optional
import requests
class GitHubCommentPoster:
"""Handles posting benchmark results as GitHub PR comments."""
def __init__(self, token: str, repo_owner: str, repo_name: str):
self.token = token
self.repo_owner = repo_owner
self.repo_name = repo_name
self.base_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}"
self.headers = {
"Authorization": f"token {token}",
"Accept": "application/vnd.github.v3+json",
}
def read_benchmark_results(self, results_file: str) -> Dict[str, str]:
"""Read benchmark results from file."""
results = {}
filepath = Path(results_file)
if not filepath.exists():
print(f"Results file not found: {filepath}")
return {"error": "Results file not found"}
try:
with open(filepath, "r") as f:
for line in f:
line = line.strip()
if "=" in line:
key, value = line.split("=", 1)
results[key] = value
except Exception as e:
print(f"Error reading results file: {e}")
return {"error": str(e)}
return results
def format_benchmark_comment(
self, results: Dict[str, str], pr_number: int, commit_sha: str
) -> str:
"""Format benchmark results into a GitHub comment."""
serialization_time = results.get("serialization_time", "N/A")
deserialization_time = results.get("deserialization_time", "N/A")
adaptation_time = results.get("adaptation_time", "N/A")
total_time = results.get("total_time", "N/A")
comment = f"""
### SGLang Router Benchmark Results
**Performance Summary for PR #{pr_number}**
The router benchmarks have completed successfully!
**Performance Thresholds:** All passed
- Serialization: < 2μs
- Deserialization: < 2μs
- PD Adaptation: < 5μs
- Total Pipeline: < 10μs
**Measured Results:**
- Serialization: `{serialization_time}`ns
- Deserialization: `{deserialization_time}`ns
- PD Adaptation: `{adaptation_time}`ns
- Total Pipeline: `{total_time}`ns
**Detailed Reports:**
- Download the `benchmark-results-{commit_sha}` artifact for HTML reports
- Run `make bench` locally for detailed analysis
**Commit:** {commit_sha}
""".strip()
return comment
def find_existing_comment(self, pr_number: int) -> Optional[int]:
"""Find existing benchmark comment in the PR."""
url = f"{self.base_url}/issues/{pr_number}/comments"
try:
response = requests.get(url, headers=self.headers)
response.raise_for_status()
comments = response.json()
for comment in comments:
if comment.get("user", {}).get(
"login"
) == "github-actions[bot]" and "SGLang Router Benchmark Results" in comment.get(
"body", ""
):
return comment["id"]
except requests.RequestException as e:
print(f"Error fetching comments: {e}")
return None
def post_comment(self, pr_number: int, comment_body: str) -> bool:
"""Post a new comment on the PR."""
url = f"{self.base_url}/issues/{pr_number}/comments"
data = {"body": comment_body}
try:
response = requests.post(url, headers=self.headers, json=data)
response.raise_for_status()
print(f"Posted new benchmark comment on PR #{pr_number}")
return True
except requests.RequestException as e:
print(f"Error posting comment: {e}")
return False
def update_comment(self, comment_id: int, comment_body: str) -> bool:
"""Update an existing comment."""
url = f"{self.base_url}/issues/comments/{comment_id}"
data = {"body": comment_body}
try:
response = requests.patch(url, headers=self.headers, json=data)
response.raise_for_status()
print(f"Updated existing benchmark comment (ID: {comment_id})")
return True
except requests.RequestException as e:
print(f"Error updating comment: {e}")
return False
def post_or_update_comment(
self, pr_number: int, results_file: str, commit_sha: str
) -> bool:
"""Post or update benchmark results comment on PR."""
# Read benchmark results
results = self.read_benchmark_results(results_file)
if "error" in results:
print(f"Failed to read benchmark results: {results['error']}")
return False
# Format comment
comment_body = self.format_benchmark_comment(results, pr_number, commit_sha)
# Check for existing comment
existing_comment_id = self.find_existing_comment(pr_number)
if existing_comment_id:
return self.update_comment(existing_comment_id, comment_body)
else:
return self.post_comment(pr_number, comment_body)
def main():
parser = argparse.ArgumentParser(description="Post benchmark results to GitHub PR")
parser.add_argument(
"--pr-number", type=int, required=True, help="Pull request number"
)
parser.add_argument("--commit-sha", type=str, required=True, help="Commit SHA")
parser.add_argument(
"--results-file",
type=str,
default="benchmark_results.env",
help="Path to benchmark results file",
)
parser.add_argument(
"--repo-owner", type=str, default="sgl-project", help="GitHub repository owner"
)
parser.add_argument(
"--repo-name", type=str, default="sglang", help="GitHub repository name"
)
args = parser.parse_args()
# Get GitHub token from environment
github_token = os.environ.get("GITHUB_TOKEN")
if not github_token:
print("Error: GITHUB_TOKEN environment variable is required")
sys.exit(1)
# Create poster and post comment
poster = GitHubCommentPoster(github_token, args.repo_owner, args.repo_name)
success = poster.post_or_update_comment(
args.pr_number, args.results_file, args.commit_sha
)
if not success:
print("Failed to post benchmark comment")
sys.exit(1)
print("Benchmark comment posted successfully!")
if __name__ == "__main__":
main()
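Because format_benchmark_comment only does string formatting (no GitHub API calls), the comment body can be previewed offline. A minimal sketch, assuming the script is importable as post_benchmark_comment (for example when run from sgl-router/scripts/ with requests installed) and using illustrative timing values:

# Offline preview of the PR comment body; the token is a placeholder since
# no API request is made here.
from post_benchmark_comment import GitHubCommentPoster

poster = GitHubCommentPoster("dummy-token", "sgl-project", "sglang")
results = {
    "serialization_time": "481",   # illustrative values, in nanoseconds
    "deserialization_time": "912",
    "adaptation_time": "35",
    "total_time": "1428",
}
print(poster.format_benchmark_comment(results, pr_number=7498, commit_sha="3abc3036"))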
#!/usr/bin/env python3
"""
SGLang Router Benchmark Runner
A Python script to run Rust benchmarks with various options and modes.
Replaces the shell script for better maintainability and integration.
"""
import argparse
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional
class BenchmarkRunner:
"""Handles running Rust benchmarks for the SGLang router."""
def __init__(self, project_root: str):
self.project_root = Path(project_root)
self.timestamp = time.strftime("%a %b %d %H:%M:%S UTC %Y", time.gmtime())
def run_command(
self, cmd: List[str], capture_output: bool = False
) -> subprocess.CompletedProcess:
"""Run a command and handle errors."""
try:
if capture_output:
result = subprocess.run(
cmd, capture_output=True, text=True, cwd=self.project_root
)
else:
result = subprocess.run(cmd, cwd=self.project_root)
return result
except subprocess.CalledProcessError as e:
print(f"Error running command: {' '.join(cmd)}")
print(f"Exit code: {e.returncode}")
sys.exit(1)
def print_header(self):
"""Print the benchmark runner header."""
print("SGLang Router Benchmark Runner")
print("=" * 30)
print(f"Project: {self.project_root.absolute()}")
print(f"Timestamp: {self.timestamp}")
print()
def build_release(self):
"""Build the project in release mode."""
print("Building in release mode...")
result = self.run_command(["cargo", "build", "--release", "--quiet"])
if result.returncode != 0:
print("Failed to build in release mode")
sys.exit(1)
def run_benchmarks(
self,
quick_mode: bool = False,
save_baseline: Optional[str] = None,
compare_baseline: Optional[str] = None,
) -> str:
"""Run benchmarks with specified options."""
bench_args = ["cargo", "bench", "--bench", "request_processing"]
if quick_mode:
bench_args.append("benchmark_summary")
print("Running quick benchmarks...")
else:
print("Running full benchmark suite...")
# Note: Criterion manages its own baselines under the target/ directory.
# For now, baseline support here just saves the raw benchmark output to a file.
if save_baseline:
print(f"Will save results as baseline: {save_baseline}")
if compare_baseline:
print(f"Will compare with baseline: {compare_baseline}")
print(f"Executing: {' '.join(bench_args)}")
result = self.run_command(bench_args, capture_output=True)
if result.returncode != 0:
print("Benchmark execution failed!")
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
sys.exit(1)
# Handle baseline saving after successful run
if save_baseline:
self._save_baseline(save_baseline, result.stdout)
return result.stdout
def _save_baseline(self, filename: str, output: str):
"""Save benchmark results to a file as baseline."""
filepath = self.project_root / filename
with open(filepath, "w") as f:
f.write(output)
print(f"Baseline saved to: {filepath}")
def parse_benchmark_results(self, output: str) -> Dict[str, str]:
"""Parse benchmark output to extract performance metrics."""
results = {}
# Look for performance overview section
lines = output.split("\n")
parsing_overview = False
for line in lines:
line = line.strip()
if "Quick Performance Overview:" in line:
parsing_overview = True
continue
if parsing_overview and line.startswith("* "):
# Parse lines like "* Serialization (avg): 481 ns/req"
if "Serialization (avg):" in line:
results["serialization_time"] = self._extract_time(line)
elif "Deserialization (avg):" in line:
results["deserialization_time"] = self._extract_time(line)
elif "PD Adaptation (avg):" in line:
results["adaptation_time"] = self._extract_time(line)
elif "Total Pipeline (avg):" in line:
results["total_time"] = self._extract_time(line)
# Stop parsing after the overview section
if parsing_overview and line.startswith("Performance Insights:"):
break
return results
def _extract_time(self, line: str) -> str:
"""Extract time value from a benchmark line."""
# Extract number followed by ns/req
import re
match = re.search(r"(\d+)\s*ns/req", line)
return match.group(1) if match else "N/A"
def validate_thresholds(self, results: Dict[str, str]) -> bool:
"""Validate benchmark results against performance thresholds."""
thresholds = {
"serialization_time": 2000, # 2μs max
"deserialization_time": 2000, # 2μs max
"adaptation_time": 5000, # 5μs max
"total_time": 10000, # 10μs max
}
all_passed = True
print("\nPerformance Threshold Validation:")
print("=" * 35)
for metric, threshold in thresholds.items():
if metric in results and results[metric] != "N/A":
try:
value = int(results[metric])
passed = value <= threshold
status = "✓ PASS" if passed else "✗ FAIL"
print(f"{metric:20}: {value:>6}ns <= {threshold:>6}ns {status}")
if not passed:
all_passed = False
except ValueError:
print(f"{metric:20}: Invalid value: {results[metric]}")
all_passed = False
else:
print(f"{metric:20}: No data available")
all_passed = False
print()
if all_passed:
print("All performance thresholds passed!")
else:
print("Some performance thresholds failed!")
return all_passed
def save_results_to_file(
self, results: Dict[str, str], filename: str = "benchmark_results.env"
):
"""Save benchmark results to a file for CI consumption."""
filepath = self.project_root / filename
with open(filepath, "w") as f:
for key, value in results.items():
f.write(f"{key}={value}\n")
print(f"Results saved to: {filepath}")
def main():
parser = argparse.ArgumentParser(description="Run SGLang router benchmarks")
parser.add_argument(
"--quick", action="store_true", help="Run quick benchmarks (summary only)"
)
parser.add_argument(
"--save-baseline", type=str, help="Save benchmark results as baseline"
)
parser.add_argument(
"--compare-baseline", type=str, help="Compare with saved baseline"
)
parser.add_argument(
"--validate-thresholds",
action="store_true",
help="Validate results against performance thresholds",
)
parser.add_argument(
"--save-results", action="store_true", help="Save results to file for CI"
)
args = parser.parse_args()
# Determine project root (script is in scripts/ subdirectory)
script_dir = Path(__file__).parent
project_root = script_dir.parent
runner = BenchmarkRunner(str(project_root))
runner.print_header()
# Build in release mode
runner.build_release()
# Run benchmarks
output = runner.run_benchmarks(
quick_mode=args.quick,
save_baseline=args.save_baseline,
compare_baseline=args.compare_baseline,
)
# Print the raw output
print(output)
# Parse and validate results if requested
if args.validate_thresholds or args.save_results:
results = runner.parse_benchmark_results(output)
if args.save_results:
runner.save_results_to_file(results)
if args.validate_thresholds:
passed = runner.validate_thresholds(results)
if not passed:
print("Validation failed - performance regression detected!")
sys.exit(1)
print("\nBenchmark run completed successfully!")
if __name__ == "__main__":
main()
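The stdout parsing and threshold validation above can also be exercised without running Criterion. A minimal sketch, assuming the script is importable as run_benchmarks (for example when run from sgl-router/scripts/); the sample overview block and its numbers are illustrative:

# Feed a fabricated "Quick Performance Overview" block (the format printed by
# the benchmark_summary target) through the parser and the threshold check.
from run_benchmarks import BenchmarkRunner

sample_output = """
Quick Performance Overview:
 * Serialization (avg):      481 ns/req
 * Deserialization (avg):    912 ns/req
 * PD Adaptation (avg):       35 ns/req
 * Total Pipeline (avg):    1428 ns/req

Performance Insights:
"""

runner = BenchmarkRunner(".")
results = runner.parse_benchmark_results(sample_output)
print(results)  # {'serialization_time': '481', ..., 'total_time': '1428'}
assert runner.validate_thresholds(results)  # all four values are under budget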
// Integration test to ensure benchmarks compile and basic functionality works
// This prevents benchmarks from breaking in CI
use serde_json::{from_str, to_string};
use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest};
#[test]
fn test_benchmark_request_creation() {
// Ensure all benchmark request types can be created without panicking
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
input_ids: None,
prompt: None,
parameters: Some(GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
repetition_penalty: Some(1.0),
..Default::default()
}),
sampling_params: Some(SamplingParams {
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
repetition_penalty: Some(1.0),
..Default::default()
}),
stream: false,
return_logprob: false,
};
let chat_req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Test message".to_string()),
name: None,
}],
max_tokens: Some(150),
max_completion_tokens: Some(150),
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
stream: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
response_format: None,
seed: None,
tools: None,
tool_choice: None,
parallel_tool_calls: Some(true),
function_call: None,
functions: None,
};
let completion_req = CompletionRequest {
model: "test-model".to_string(),
prompt: StringOrArray::String("Test prompt".to_string()),
suffix: None,
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
stream: false,
logprobs: None,
echo: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
logit_bias: None,
user: None,
seed: None,
};
// Test serialization works
assert!(to_string(&generate_req).is_ok());
assert!(to_string(&chat_req).is_ok());
assert!(to_string(&completion_req).is_ok());
}
#[test]
fn test_benchmark_serialization_roundtrip() {
// Test serialization/deserialization roundtrip for benchmark types
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
// Serialize and deserialize
let json = to_string(&generate_req).expect("Serialization should work");
let deserialized: GenerateRequest = from_str(&json).expect("Deserialization should work");
// Verify basic field equality
assert_eq!(generate_req.text, deserialized.text);
assert_eq!(generate_req.stream, deserialized.stream);
assert_eq!(generate_req.return_logprob, deserialized.return_logprob);
}
#[test]
fn test_benchmark_request_adaptation() {
// Test that PD request adaptation works for benchmark types
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
let chat_req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Test message".to_string()),
name: None,
}],
max_tokens: Some(150),
max_completion_tokens: Some(150),
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
stream: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
response_format: None,
seed: None,
tools: None,
tool_choice: None,
parallel_tool_calls: Some(true),
function_call: None,
functions: None,
};
let completion_req = CompletionRequest {
model: "test-model".to_string(),
prompt: StringOrArray::String("Test prompt".to_string()),
suffix: None,
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
stream: false,
logprobs: None,
echo: false,
stop: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
logit_bias: None,
user: None,
seed: None,
};
// Test PD adaptation (should not panic)
let _pd_generate = generate_req.to_pd_request();
let _pd_chat = chat_req.to_pd_request();
let _pd_completion = completion_req.to_pd_request();
}
#[test]
fn test_benchmark_regular_routing() {
// Test regular routing functionality for benchmark types
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
// Test regular routing methods (should not panic)
let _json = generate_req.to_json();
let _bytes = generate_req.to_bytes();
}
#[test]
fn test_benchmark_performance_baseline() {
// Basic performance sanity check - ensure operations complete quickly
use std::time::Instant;
let generate_req = GenerateRequest {
text: Some("Short test prompt".to_string()),
input_ids: None,
prompt: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
};
// Serialization should be fast (< 1ms for simple requests)
let start = Instant::now();
let _json = to_string(&generate_req).unwrap();
let serialize_duration = start.elapsed();
assert!(
serialize_duration.as_millis() < 1,
"Serialization took too long: {:?}",
serialize_duration
);
// PD adaptation should be very fast (< 1ms)
let start = Instant::now();
let _pd_req = generate_req.to_pd_request();
let adapt_duration = start.elapsed();
assert!(
adapt_duration.as_millis() < 1,
"PD adaptation took too long: {:?}",
adapt_duration
);
}