use std::time::Duration; use tonic::{transport::Channel, Request}; use tracing::debug; // Include the generated protobuf code pub mod proto { tonic::include_proto!("sglang.grpc.scheduler"); } // The generated module structure depends on the package name in the .proto file // package sglang.grpc.scheduler; generates a nested module structure /// gRPC client for SGLang scheduler pub struct SglangSchedulerClient { client: proto::sglang_scheduler_client::SglangSchedulerClient, } impl SglangSchedulerClient { /// Create a new client and connect to the scheduler pub async fn connect(endpoint: &str) -> Result> { debug!("Connecting to SGLang scheduler at {}", endpoint); // Convert grpc:// to http:// for tonic let http_endpoint = if endpoint.starts_with("grpc://") { endpoint.replace("grpc://", "http://") } else { endpoint.to_string() }; let channel = Channel::from_shared(http_endpoint)? .timeout(Duration::from_secs(30)) .connect() .await?; let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel); Ok(Self { client }) } /// Submit a generation request (returns streaming response) pub async fn generate_stream( &mut self, req: proto::GenerateRequest, ) -> Result, Box> { let request = Request::new(req); let response = self.client.generate(request).await?; Ok(response.into_inner()) } /// Perform health check pub async fn health_check( &mut self, ) -> Result> { debug!("Sending health check request"); let request = Request::new(proto::HealthCheckRequest { tokenized: Some(proto::TokenizedInput { original_text: "Hello".to_string(), input_ids: vec![9906], // Mock token ID for "Hello" }), }); let response = self.client.health_check(request).await?; debug!("Health check response received"); Ok(response.into_inner()) } /// Abort a request pub async fn abort_request( &mut self, request_id: String, reason: String, ) -> Result<(), Box> { let request = Request::new(proto::AbortRequest { request_id, reason }); self.client.abort(request).await?; Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_proto_types_compilation() { // Test that protobuf types can be constructed let health_req = proto::HealthCheckRequest { tokenized: Some(proto::TokenizedInput { original_text: "test".to_string(), input_ids: vec![1296], }), }; assert!(health_req.tokenized.is_some()); } #[test] fn test_generate_request_construction() { let sampling_params = proto::SamplingParams { temperature: 0.7, max_new_tokens: 128, top_p: 0.9, top_k: 50, stop: vec!["".to_string()], ..Default::default() }; let gen_req = proto::GenerateRequest { request_id: "test-req-123".to_string(), tokenized: Some(proto::TokenizedInput { original_text: "Hello world".to_string(), input_ids: vec![9906, 1917], // Mock token IDs for "Hello world" }), sampling_params: Some(sampling_params), return_logprob: true, logprob_start_len: 0, top_logprobs_num: 5, ..Default::default() }; assert_eq!(gen_req.request_id, "test-req-123"); if let Some(ref tokenized) = &gen_req.tokenized { assert_eq!(tokenized.original_text, "Hello world"); } assert!(gen_req.return_logprob); assert_eq!(gen_req.top_logprobs_num, 5); let params = gen_req.sampling_params.unwrap(); assert_eq!(params.temperature, 0.7); assert_eq!(params.max_new_tokens, 128); assert_eq!(params.stop, vec![""]); } #[test] fn test_health_check_request() { let health_req = proto::HealthCheckRequest { tokenized: Some(proto::TokenizedInput { original_text: "test".to_string(), input_ids: vec![1296], // Mock token ID for "test" }), }; assert!(health_req.tokenized.is_some()); } #[test] fn test_abort_request_construction() { let abort_req = proto::AbortRequest { request_id: "req-456".to_string(), reason: "User canceled".to_string(), }; assert_eq!(abort_req.request_id, "req-456"); assert_eq!(abort_req.reason, "User canceled"); } #[test] fn test_sampling_params_defaults() { let params = proto::SamplingParams::default(); assert_eq!(params.temperature, 0.0); assert_eq!(params.max_new_tokens, 0); assert_eq!(params.top_p, 0.0); assert_eq!(params.top_k, 0); assert!(params.stop.is_empty()); } #[test] fn test_multimodal_inputs() { let mm_inputs = proto::MultimodalInputs { image_urls: vec!["http://example.com/image.jpg".to_string()], video_urls: vec![], audio_urls: vec![], image_data: vec![], video_data: vec![], audio_data: vec![], modalities: vec!["image".to_string()], ..Default::default() }; assert_eq!(mm_inputs.image_urls.len(), 1); assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg"); assert_eq!(mm_inputs.modalities[0], "image"); } // TODO: SessionParams not in current proto - skip test // #[test] // fn test_session_params() { ... } #[test] fn test_embed_request() { let embed_req = proto::EmbedRequest { request_id: "embed-req-202".to_string(), tokenized: Some(proto::TokenizedInput { original_text: "This is a test sentence for embedding".to_string(), input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs }), log_metrics: true, data_parallel_rank: 0, ..Default::default() }; assert_eq!(embed_req.request_id, "embed-req-202"); if let Some(ref tokenized) = &embed_req.tokenized { assert_eq!( tokenized.original_text, "This is a test sentence for embedding" ); } assert!(embed_req.log_metrics); assert_eq!(embed_req.data_parallel_rank, 0); } #[tokio::test] async fn test_client_connect_invalid_endpoint() { // Test connecting to an invalid endpoint should return error let result = SglangSchedulerClient::connect("invalid://endpoint").await; assert!(result.is_err()); } #[test] fn test_tokenized_input() { let tokenized = proto::TokenizedInput { original_text: "Hello world".to_string(), input_ids: vec![1, 15043, 1917, 2], }; assert_eq!(tokenized.original_text, "Hello world"); assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]); } // Test response type construction #[test] fn test_generate_stream_chunk() { let chunk = proto::GenerateStreamChunk { token_id: 1234, text: " world".to_string(), prompt_tokens: 5, completion_tokens: 2, cached_tokens: 3, generation_time: 0.025, queue_time: 10, ..Default::default() }; assert_eq!(chunk.token_id, 1234); assert_eq!(chunk.text, " world"); assert_eq!(chunk.prompt_tokens, 5); assert_eq!(chunk.completion_tokens, 2); assert_eq!(chunk.cached_tokens, 3); assert_eq!(chunk.generation_time, 0.025); assert_eq!(chunk.queue_time, 10); } // TODO: ModelInfo not in current proto - skip test // #[test] // fn test_model_info() { ... } }