// gRPC Router Implementation

use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
};
use tokio::sync::RwLock;
use tracing::debug;

use super::{
    context::SharedComponents,
    pipeline::RequestPipeline,
    responses::{self, BackgroundTaskInfo},
};
use crate::{
    app_context::AppContext,
    config::types::RetryConfig,
    core::WorkerRegistry,
    data_connector::{
        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
    },
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
        classify::ClassifyRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    reasoning_parser::ParserFactory as ReasoningParserFactory,
    routers::RouterTrait,
    tokenizer::traits::Tokenizer,
    tool_parser::ParserFactory as ToolParserFactory,
};

/// gRPC router implementation for SGLang
44
#[derive(Clone)]
45
#[allow(dead_code)]
46
pub struct GrpcRouter {
47
48
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
49
    tokenizer: Arc<dyn Tokenizer>,
50
    reasoning_parser_factory: ReasoningParserFactory,
51
    tool_parser_factory: ToolParserFactory,
52
53
54
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
55
56
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
57
58
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
59
60
61
62
63
64
65
66
    // Storage backends for /v1/responses support
    response_storage: SharedResponseStorage,
    conversation_storage: SharedConversationStorage,
    conversation_item_storage: SharedConversationItemStorage,
    // Optional MCP manager for tool execution (enabled via SGLANG_MCP_CONFIG env var)
    mcp_manager: Option<Arc<crate::mcp::McpClientManager>>,
    // Background task handles for cancellation support (includes gRPC client for Python abort)
    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
67
}

impl GrpcRouter {
70
    /// Create a new gRPC router
71
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
72
73
74
75
76
77
78
79
80
81
82
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
83
84
85
86
87
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
88

89
90
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
91

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
        // Extract storage backends from context
        let response_storage = ctx.response_storage.clone();
        let conversation_storage = ctx.conversation_storage.clone();
        let conversation_item_storage = ctx.conversation_item_storage.clone();

        // Optional MCP manager activation via env var path (config-driven gate)
        let mcp_manager = match std::env::var("SGLANG_MCP_CONFIG").ok() {
            Some(path) if !path.trim().is_empty() => {
                match crate::mcp::McpConfig::from_file(&path).await {
                    Ok(cfg) => match crate::mcp::McpClientManager::new(cfg).await {
                        Ok(mgr) => Some(Arc::new(mgr)),
                        Err(err) => {
                            tracing::warn!("Failed to initialize MCP manager: {}", err);
                            None
                        }
                    },
                    Err(err) => {
                        tracing::warn!("Failed to load MCP config from '{}': {}", path, err);
                        None
                    }
                }
            }
            _ => None,
        };

117
        // Create shared components for pipeline
118
        let shared_components = Arc::new(SharedComponents {
119
120
121
122
123
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

124
125
126
127
        // Create pipeline
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
            policy_registry.clone(),
128
129
130
131
132
133
134
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

135
        Ok(GrpcRouter {
136
137
            worker_registry,
            policy_registry,
138
139
            tokenizer,
            reasoning_parser_factory,
140
            tool_parser_factory,
141
142
143
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
144
145
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
146
147
            pipeline,
            shared_components,
148
149
150
151
152
            response_storage,
            conversation_storage,
            conversation_item_storage,
            mcp_manager,
            background_tasks: Arc::new(RwLock::new(HashMap::new())),
153
154
        })
    }
155
156
157
158

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
159
        headers: Option<&HeaderMap>,
160
161
162
163
164
165
166
167
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?}",
            model_id
        );

168
169
170
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
171
                Arc::new(body.clone()),
172
173
174
175
176
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
177
178
    }

179
180
181
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
182
        headers: Option<&HeaderMap>,
183
184
185
186
187
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

188
189
190
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
191
                Arc::new(body.clone()),
192
193
194
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
195
            )
196
            .await
197
    }
198
199
200
201
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202
        let stats = self.worker_registry.stats();
203
        f.debug_struct("GrpcRouter")
204
            .field("workers_count", &stats.total_workers)
205
206
            .field("dp_aware", &self.dp_aware)
            .finish()
207
208
209
210
211
212
213
214
215
216
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
217
218
219
220
221
222
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
239
240
241
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
242
    ) -> Response {
243
        self.route_generate_impl(headers, body, model_id).await
244
245
246
247
    }

    async fn route_chat(
        &self,
248
        headers: Option<&HeaderMap>,
249
        body: &ChatCompletionRequest,
250
        model_id: Option<&str>,
251
    ) -> Response {
252
        self.route_chat_impl(headers, body, model_id).await
253
254
255
256
257
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
258
        _body: &CompletionRequest,
259
        _model_id: Option<&str>,
260
261
262
263
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

264
265
    async fn route_responses(
        &self,
266
267
268
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
269
    ) -> Response {
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
        // Use responses module for ALL requests (streaming and non-streaming)
        // Responses module handles:
        // - Request validation (previous_response_id XOR conversation)
        // - Loading response chain / conversation history from storage
        // - Conversion: ResponsesRequest → ChatCompletionRequest
        // - Execution through chat pipeline stages
        // - Conversion: ChatCompletionResponse → ResponsesResponse
        // - Response persistence
        // - MCP tool loop wrapper (future)
        responses::route_responses(
            &self.pipeline,
            Arc::new(body.clone()),
            headers.cloned(),
            model_id.map(|s| s.to_string()),
            self.shared_components.clone(),
            self.response_storage.clone(),
            self.conversation_storage.clone(),
            self.conversation_item_storage.clone(),
            self.background_tasks.clone(),
        )
        .await
291
292
    }

293
294
295
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
296
        response_id: &str,
297
        _params: &ResponsesGetParams,
298
    ) -> Response {
299
        responses::get_response_impl(&self.response_storage, response_id).await
300
301
    }

302
303
304
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
        responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
            .await
305
306
    }

307
308
309
310
311
312
313
314
315
    async fn route_classify(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &ClassifyRequest,
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

316
317
318
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
319
        _body: &EmbeddingRequest,
320
321
        _model_id: Option<&str>,
    ) -> Response {
322
323
324
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

325
326
327
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
328
        _body: &RerankRequest,
329
        _model_id: Option<&str>,
330
    ) -> Response {
331
332
333
334
335
336
337
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}