use std::sync::Arc; use async_trait::async_trait; use axum::{ body::Body, extract::Request, http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, }; use tracing::debug; use super::{context::SharedComponents, pipeline::RequestPipeline}; use crate::{ app_context::AppContext, core::{ConnectionMode, WorkerRegistry, WorkerType}, protocols::{ chat::ChatCompletionRequest, classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, rerank::RerankRequest, responses::{ResponsesGetParams, ResponsesRequest}, }, routers::RouterTrait, }; /// gRPC PD (Prefill-Decode) router implementation for SGLang #[derive(Clone)] pub struct GrpcPDRouter { worker_registry: Arc, pipeline: RequestPipeline, shared_components: Arc, } impl GrpcPDRouter { /// Create a new gRPC PD router pub async fn new(ctx: &Arc) -> Result { // Get registries from context let worker_registry = ctx.worker_registry.clone(); let policy_registry = ctx.policy_registry.clone(); // Extract necessary components from context let tokenizer = ctx .tokenizer .as_ref() .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())? .clone(); let reasoning_parser_factory = ctx .reasoning_parser_factory .as_ref() .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())? .clone(); let tool_parser_factory = ctx .tool_parser_factory .as_ref() .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())? .clone(); // Create shared components for pipeline let shared_components = Arc::new(SharedComponents { tokenizer: tokenizer.clone(), tool_parser_factory: tool_parser_factory.clone(), reasoning_parser_factory: reasoning_parser_factory.clone(), }); // Create PD pipeline let pipeline = RequestPipeline::new_pd( worker_registry.clone(), policy_registry.clone(), tokenizer.clone(), tool_parser_factory.clone(), reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), ); Ok(GrpcPDRouter { worker_registry, pipeline, shared_components, }) } /// Main route_generate implementation with PD dual dispatch async fn route_generate_impl( &self, headers: Option<&HeaderMap>, body: &GenerateRequest, model_id: Option<&str>, ) -> Response { debug!( "Processing generate request for model: {:?} (PD mode)", model_id ); // Use pipeline for ALL requests (streaming and non-streaming) self.pipeline .execute_generate( Arc::new(body.clone()), headers.cloned(), model_id.map(|s| s.to_string()), self.shared_components.clone(), ) .await } /// Main route_chat implementation with PD dual dispatch async fn route_chat_impl( &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, model_id: Option<&str>, ) -> Response { debug!( "Processing chat completion request for model: {:?} (PD mode)", model_id ); // Use pipeline for ALL requests (streaming and non-streaming) self.pipeline .execute_chat( Arc::new(body.clone()), headers.cloned(), model_id.map(|s| s.to_string()), self.shared_components.clone(), ) .await } } impl std::fmt::Debug for GrpcPDRouter { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let prefill_workers = self.worker_registry.get_workers_filtered( None, Some(WorkerType::Prefill { bootstrap_port: None, }), Some(ConnectionMode::Grpc { port: None }), false, ); let decode_workers = self.worker_registry.get_workers_filtered( None, Some(WorkerType::Decode), Some(ConnectionMode::Grpc { port: None }), false, ); f.debug_struct("GrpcPDRouter") .field("prefill_workers_count", &prefill_workers.len()) .field("decode_workers_count", &decode_workers.len()) .finish() } } #[async_trait] impl RouterTrait for GrpcPDRouter { fn as_any(&self) -> &dyn std::any::Any { self } async fn health_generate(&self, _req: Request) -> Response { ( StatusCode::NOT_IMPLEMENTED, "Health generate not yet implemented for gRPC PD", ) .into_response() } async fn get_server_info(&self, _req: Request) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn get_models(&self, _req: Request) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn get_model_info(&self, _req: Request) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn route_generate( &self, headers: Option<&HeaderMap>, body: &GenerateRequest, model_id: Option<&str>, ) -> Response { self.route_generate_impl(headers, body, model_id).await } async fn route_chat( &self, headers: Option<&HeaderMap>, body: &ChatCompletionRequest, model_id: Option<&str>, ) -> Response { self.route_chat_impl(headers, body, model_id).await } async fn route_completion( &self, _headers: Option<&HeaderMap>, _body: &CompletionRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn route_responses( &self, _headers: Option<&HeaderMap>, _body: &ResponsesRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn get_response( &self, _headers: Option<&HeaderMap>, _response_id: &str, _params: &ResponsesGetParams, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn route_embeddings( &self, _headers: Option<&HeaderMap>, _body: &EmbeddingRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn route_classify( &self, _headers: Option<&HeaderMap>, _body: &ClassifyRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } async fn route_rerank( &self, _headers: Option<&HeaderMap>, _body: &RerankRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } fn router_type(&self) -> &'static str { "grpc_pd" } }