Unverified Commit 7638f5e4 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Implement gRPC SGLangSchedulerClient (#9364)

parent b45f753c
...@@ -47,7 +47,7 @@ jobs: ...@@ -47,7 +47,7 @@ jobs:
env: env:
CIBW_BUILD: "cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64" CIBW_BUILD: "cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64"
CIBW_BEFORE_ALL: | CIBW_BEFORE_ALL: |
yum update && yum install -y openssl-devel && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y yum update && yum install -y openssl-devel protobuf-compiler && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
CIBW_ENVIRONMENT: "PATH=$HOME/.cargo/bin:$PATH" CIBW_ENVIRONMENT: "PATH=$HOME/.cargo/bin:$PATH"
- name: List built packages - name: List built packages
......
...@@ -39,13 +39,13 @@ ENV PATH="/root/.cargo/bin:${PATH}" ...@@ -39,13 +39,13 @@ ENV PATH="/root/.cargo/bin:${PATH}"
# install dependencies # install dependencies
RUN apt update -y \ RUN apt update -y \
&& apt install -y git build-essential libssl-dev pkg-config \ && apt install -y git build-essential libssl-dev pkg-config protobuf-compiler \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& apt clean && apt clean
# install rustup from rustup.rs # install rustup from rustup.rs
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& rustc --version && cargo --version && rustc --version && cargo --version && protoc --version
# pull the github repository # pull the github repository
RUN cd /opt \ RUN cd /opt \
......
...@@ -4,10 +4,10 @@ set -euxo pipefail ...@@ -4,10 +4,10 @@ set -euxo pipefail
# Check if sudo is available # Check if sudo is available
if command -v sudo >/dev/null 2>&1; then if command -v sudo >/dev/null 2>&1; then
sudo apt-get update sudo apt-get update
sudo apt-get install -y libssl-dev pkg-config sudo apt-get install -y libssl-dev pkg-config protobuf-compiler
else else
apt-get update apt-get update
apt-get install -y libssl-dev pkg-config apt-get install -y libssl-dev pkg-config protobuf-compiler
fi fi
# Install rustup (Rust installer and version manager) # Install rustup (Rust installer and version manager)
...@@ -21,3 +21,4 @@ source $HOME/.cargo/env ...@@ -21,3 +21,4 @@ source $HOME/.cargo/env
# Verify installation # Verify installation
rustc --version rustc --version
cargo --version cargo --version
protoc --version
...@@ -4,9 +4,11 @@ version = "0.0.0" ...@@ -4,9 +4,11 @@ version = "0.0.0"
edition = "2021" edition = "2021"
[features] [features]
default = ["huggingface"] default = ["huggingface", "grpc-client"]
huggingface = ["tokenizers"] huggingface = ["tokenizers"]
tiktoken = ["tiktoken-rs"] tiktoken = ["tiktoken-rs"]
grpc-client = []
grpc-server = []
[lib] [lib]
name = "sglang_router_rs" name = "sglang_router_rs"
...@@ -52,6 +54,18 @@ anyhow = "1.0" ...@@ -52,6 +54,18 @@ anyhow = "1.0"
tokenizers = { version = "0.21.4", optional = true } tokenizers = { version = "0.21.4", optional = true }
tiktoken-rs = { version = "0.5", optional = true } tiktoken-rs = { version = "0.5", optional = true }
# gRPC and Protobuf dependencies
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
prost = "0.13"
prost-types = "0.13"
deadpool = { version = "0.12", features = ["managed", "rt_tokio_1"] }
backoff = { version = "0.4", features = ["tokio"] }
strum = { version = "0.26", features = ["derive"] }
[build-dependencies]
tonic-build = "0.12"
prost-build = "0.13"
[dev-dependencies] [dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] } criterion = { version = "0.5", features = ["html_reports"] }
tower = { version = "0.5", features = ["util"] } tower = { version = "0.5", features = ["util"] }
......
# Must include: # Must include:
include Cargo.toml # Rust project configuration include Cargo.toml # Rust project configuration
include build.rs # Build script for protobuf generation
recursive-include src *.rs # Rust source files recursive-include src *.rs # Rust source files
recursive-include src/proto *.proto # Protobuf definitions
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Only regenerate if the proto file changes
println!("cargo:rerun-if-changed=src/proto/sglang_scheduler.proto");
// Configure protobuf compilation with custom settings
let config = prost_build::Config::new();
// Skip serde for types that use prost_types::Struct
// These cause conflicts and we don't need serde for all generated types
// Configure tonic-build for gRPC code generation
tonic_build::configure()
// Generate both client and server code
.build_server(true)
.build_client(true)
// Add a module-level attribute for documentation and clippy warnings
.server_mod_attribute(
"sglang.grpc.scheduler",
"#[allow(unused, clippy::mixed_attributes_style)]",
)
.client_mod_attribute(
"sglang.grpc.scheduler",
"#[allow(unused, clippy::mixed_attributes_style)]",
)
// Compile the proto file with the custom config
.compile_protos_with_config(
config,
&["src/proto/sglang_scheduler.proto"],
&["src/proto"],
)?;
println!("cargo:warning=Protobuf compilation completed successfully");
Ok(())
}
use std::time::Duration;
use tonic::{transport::Channel, Request};
use tracing::debug;
// Include the generated protobuf code
pub mod proto {
tonic::include_proto!("sglang.grpc.scheduler");
}
// The generated module structure depends on the package name in the .proto file
// package sglang.grpc.scheduler; generates a nested module structure
/// gRPC client for SGLang scheduler
pub struct SglangSchedulerClient {
client: proto::sglang_scheduler_client::SglangSchedulerClient<Channel>,
}
impl SglangSchedulerClient {
/// Create a new client and connect to the scheduler
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
debug!("Connecting to SGLang scheduler at {}", endpoint);
let channel = Channel::from_shared(endpoint.to_string())?
.timeout(Duration::from_secs(30))
.connect()
.await?;
let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel);
Ok(Self { client })
}
/// Initialize the connection
pub async fn initialize(
&mut self,
client_id: String,
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::InitializeRequest {
client_id,
client_version: "0.1.0".to_string(),
mode: proto::initialize_request::Mode::Regular as i32,
});
let response = self.client.initialize(request).await?;
Ok(response.into_inner())
}
/// Submit a generation request (returns streaming response)
pub async fn generate_stream(
&mut self,
req: proto::GenerateRequest,
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
let request = Request::new(req);
let response = self.client.generate(request).await?;
Ok(response.into_inner())
}
/// Perform health check
pub async fn health_check(
&mut self,
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::HealthCheckRequest {
include_detailed_metrics: false,
});
let response = self.client.health_check(request).await?;
Ok(response.into_inner())
}
/// Abort a request
pub async fn abort_request(
&mut self,
request_id: String,
reason: String,
) -> Result<(), Box<dyn std::error::Error>> {
let request = Request::new(proto::AbortRequest { request_id, reason });
self.client.abort(request).await?;
Ok(())
}
/// Flush cache
pub async fn flush_cache(
&mut self,
flush_all: bool,
session_ids: &[String],
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::FlushCacheRequest {
flush_all,
session_ids: session_ids.to_vec(),
});
let response = self.client.flush_cache(request).await?;
Ok(response.into_inner())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_proto_types_compilation() {
// Test that protobuf types can be constructed
let init_req = proto::InitializeRequest {
client_id: "test-client".to_string(),
client_version: "0.1.0".to_string(),
mode: 0,
};
assert_eq!(init_req.client_id, "test-client");
assert_eq!(init_req.client_version, "0.1.0");
assert_eq!(init_req.mode, 0);
}
#[test]
fn test_generate_request_construction() {
let sampling_params = proto::SamplingParams {
temperature: 0.7,
max_new_tokens: 128,
top_p: 0.9,
top_k: 50,
stop: vec!["</s>".to_string()],
..Default::default()
};
let gen_req = proto::GenerateRequest {
request_id: "test-req-123".to_string(),
input: Some(proto::generate_request::Input::Text(
"Hello world".to_string(),
)),
sampling_params: Some(sampling_params),
return_logprob: true,
logprob_start_len: 0,
top_logprobs_num: 5,
..Default::default()
};
assert_eq!(gen_req.request_id, "test-req-123");
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
assert_eq!(text, "Hello world");
}
assert!(gen_req.return_logprob);
assert_eq!(gen_req.top_logprobs_num, 5);
let params = gen_req.sampling_params.unwrap();
assert_eq!(params.temperature, 0.7);
assert_eq!(params.max_new_tokens, 128);
assert_eq!(params.stop, vec!["</s>"]);
}
#[test]
fn test_health_check_request() {
let health_req = proto::HealthCheckRequest {
include_detailed_metrics: true,
};
assert!(health_req.include_detailed_metrics);
}
#[test]
fn test_abort_request_construction() {
let abort_req = proto::AbortRequest {
request_id: "req-456".to_string(),
reason: "User canceled".to_string(),
};
assert_eq!(abort_req.request_id, "req-456");
assert_eq!(abort_req.reason, "User canceled");
}
#[test]
fn test_flush_cache_request() {
let flush_req = proto::FlushCacheRequest {
flush_all: true,
session_ids: vec!["session1".to_string(), "session2".to_string()],
};
assert!(flush_req.flush_all);
assert_eq!(flush_req.session_ids.len(), 2);
assert_eq!(flush_req.session_ids[0], "session1");
}
#[test]
fn test_sampling_params_defaults() {
let params = proto::SamplingParams::default();
assert_eq!(params.temperature, 0.0);
assert_eq!(params.max_new_tokens, 0);
assert_eq!(params.top_p, 0.0);
assert_eq!(params.top_k, 0);
assert!(params.stop.is_empty());
}
#[test]
fn test_multimodal_inputs() {
let mm_inputs = proto::MultimodalInputs {
image_urls: vec!["http://example.com/image.jpg".to_string()],
video_urls: vec![],
audio_urls: vec![],
image_data: vec![],
video_data: vec![],
audio_data: vec![],
modalities: vec!["image".to_string()],
..Default::default()
};
assert_eq!(mm_inputs.image_urls.len(), 1);
assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg");
assert_eq!(mm_inputs.modalities[0], "image");
}
#[test]
fn test_session_params() {
let session_params = proto::SessionParams {
session_id: "sess-789".to_string(),
request_id: "req-101".to_string(),
offset: 100,
replace: true,
drop_previous_output: false,
};
assert_eq!(session_params.session_id, "sess-789");
assert_eq!(session_params.request_id, "req-101");
assert_eq!(session_params.offset, 100);
assert!(session_params.replace);
assert!(!session_params.drop_previous_output);
}
#[test]
fn test_embed_request() {
let embed_req = proto::EmbedRequest {
request_id: "embed-req-202".to_string(),
input: Some(proto::embed_request::Input::Text(
"This is a test sentence for embedding".to_string(),
)),
log_metrics: true,
data_parallel_rank: 0,
..Default::default()
};
assert_eq!(embed_req.request_id, "embed-req-202");
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
assert_eq!(text, "This is a test sentence for embedding");
}
assert!(embed_req.log_metrics);
assert_eq!(embed_req.data_parallel_rank, 0);
}
#[tokio::test]
async fn test_client_connect_invalid_endpoint() {
// Test connecting to an invalid endpoint should return error
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
assert!(result.is_err());
}
#[test]
fn test_tokenized_input() {
let tokenized = proto::TokenizedInput {
original_text: "Hello world".to_string(),
input_ids: vec![1, 15043, 1917, 2],
};
assert_eq!(tokenized.original_text, "Hello world");
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
}
// Test response type construction
#[test]
fn test_generate_stream_chunk() {
let chunk = proto::GenerateStreamChunk {
token_id: 1234,
text: " world".to_string(),
prompt_tokens: 5,
completion_tokens: 2,
cached_tokens: 3,
generation_time: 0.025,
queue_time: 10,
..Default::default()
};
assert_eq!(chunk.token_id, 1234);
assert_eq!(chunk.text, " world");
assert_eq!(chunk.prompt_tokens, 5);
assert_eq!(chunk.completion_tokens, 2);
assert_eq!(chunk.cached_tokens, 3);
assert_eq!(chunk.generation_time, 0.025);
assert_eq!(chunk.queue_time, 10);
}
#[test]
fn test_model_info() {
let model_info = proto::ModelInfo {
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
max_context_length: 8192,
vocab_size: 128256,
supports_tool_calling: true,
supports_vision: false,
special_tokens: vec![
"<|begin_of_text|>".to_string(),
"<|end_of_text|>".to_string(),
],
model_type: "llama".to_string(),
num_layers: 32,
hidden_size: 4096,
num_attention_heads: 32,
num_key_value_heads: 8,
tokenizer_type: "llama".to_string(),
eos_token_ids: vec![128001, 128009],
pad_token_id: 128001,
bos_token_id: 128000,
};
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
assert_eq!(model_info.max_context_length, 8192);
assert_eq!(model_info.vocab_size, 128256);
assert!(model_info.supports_tool_calling);
assert!(!model_info.supports_vision);
assert_eq!(model_info.special_tokens.len(), 2);
assert_eq!(model_info.num_layers, 32);
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
}
}
//! gRPC client module for communicating with SGLang scheduler
//!
//! This module provides a gRPC client implementation for the SGLang router.
pub mod client;
// Re-export the client
pub use client::{proto, SglangSchedulerClient};
...@@ -3,6 +3,8 @@ pub mod config; ...@@ -3,6 +3,8 @@ pub mod config;
pub mod logging; pub mod logging;
use std::collections::HashMap; use std::collections::HashMap;
pub mod core; pub mod core;
#[cfg(feature = "grpc-client")]
pub mod grpc;
pub mod metrics; pub mod metrics;
pub mod middleware; pub mod middleware;
pub mod policies; pub mod policies;
......
...@@ -7,7 +7,7 @@ import "google/protobuf/struct.proto"; ...@@ -7,7 +7,7 @@ import "google/protobuf/struct.proto";
// Service definition for SGLang scheduler communication // Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler // This protocol bridges the Rust router and Python scheduler
service SGLangScheduler { service SglangScheduler {
// Initialize connection and get model info // Initialize connection and get model info
rpc Initialize(InitializeRequest) returns (InitializeResponse); rpc Initialize(InitializeRequest) returns (InitializeResponse);
...@@ -21,7 +21,7 @@ service SGLangScheduler { ...@@ -21,7 +21,7 @@ service SGLangScheduler {
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
// Abort a running request // Abort a running request
rpc AbortRequest(AbortRequest) returns (AbortResponse); rpc Abort(AbortRequest) returns (AbortResponse);
// Flush KV cache // Flush KV cache
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse); rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment