// gRPC Router Implementation

use std::{collections::HashMap, sync::Arc};
4
5
6
7
8

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
9
    http::{HeaderMap, StatusCode},
10
11
    response::{IntoResponse, Response},
};
12
use tokio::sync::RwLock;
13
use tracing::debug;
14

15
16
17
18
19
use super::{
    context::SharedComponents,
    pipeline::RequestPipeline,
    responses::{self, BackgroundTaskInfo},
};
20
use crate::{
21
    app_context::AppContext,
22
23
    config::types::RetryConfig,
    core::WorkerRegistry,
24
25
26
    data_connector::{
        SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
    },
27
    mcp::McpManager,
28
29
30
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
31
        classify::ClassifyRequest,
32
33
34
35
36
37
38
39
40
41
42
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    reasoning_parser::ParserFactory as ReasoningParserFactory,
    routers::RouterTrait,
    tokenizer::traits::Tokenizer,
    tool_parser::ParserFactory as ToolParserFactory,
};
43

44
/// gRPC router implementation for SGLang
45
#[derive(Clone)]
46
#[allow(dead_code)]
47
pub struct GrpcRouter {
48
49
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
50
    tokenizer: Arc<dyn Tokenizer>,
51
    reasoning_parser_factory: ReasoningParserFactory,
52
    tool_parser_factory: ToolParserFactory,
53
54
55
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
56
57
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
58
59
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
60
61
62
63
    // Storage backends for /v1/responses support
    response_storage: SharedResponseStorage,
    conversation_storage: SharedConversationStorage,
    conversation_item_storage: SharedConversationItemStorage,
64
    mcp_manager: Arc<McpManager>,
65
66
    // Background task handles for cancellation support (includes gRPC client for Python abort)
    background_tasks: Arc<RwLock<HashMap<String, BackgroundTaskInfo>>>,
67
}
68
69

impl GrpcRouter {
70
    /// Create a new gRPC router
71
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
72
73
74
75
76
77
78
79
80
81
82
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
83
84
85
86
87
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
88

89
90
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
91

92
93
94
95
96
        // Extract storage backends from context
        let response_storage = ctx.response_storage.clone();
        let conversation_storage = ctx.conversation_storage.clone();
        let conversation_item_storage = ctx.conversation_item_storage.clone();

97
98
99
100
101
102
        // Get MCP manager from app context
        let mcp_manager = ctx
            .mcp_manager
            .get()
            .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
            .clone();
103

104
        // Create shared components for pipeline
105
        let shared_components = Arc::new(SharedComponents {
106
107
108
109
110
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

111
112
113
114
        // Create pipeline
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
            policy_registry.clone(),
115
116
117
118
119
120
121
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

122
        Ok(GrpcRouter {
123
124
            worker_registry,
            policy_registry,
125
126
            tokenizer,
            reasoning_parser_factory,
127
            tool_parser_factory,
128
129
130
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
131
132
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
133
134
            pipeline,
            shared_components,
135
136
137
138
139
            response_storage,
            conversation_storage,
            conversation_item_storage,
            mcp_manager,
            background_tasks: Arc::new(RwLock::new(HashMap::new())),
140
141
        })
    }
142
143
144
145

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
146
        headers: Option<&HeaderMap>,
147
148
149
150
151
152
153
154
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?}",
            model_id
        );

155
156
157
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
158
                Arc::new(body.clone()),
159
160
161
162
163
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
164
165
    }

166
167
168
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
169
        headers: Option<&HeaderMap>,
170
171
172
173
174
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

175
176
177
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
178
                Arc::new(body.clone()),
179
180
181
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
182
            )
183
            .await
184
    }
185
186
187
188
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189
        let stats = self.worker_registry.stats();
190
        f.debug_struct("GrpcRouter")
191
            .field("workers_count", &stats.total_workers)
192
193
            .field("dp_aware", &self.dp_aware)
            .finish()
194
195
196
197
198
199
200
201
202
203
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
204
205
206
207
208
209
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
226
227
228
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
229
    ) -> Response {
230
        self.route_generate_impl(headers, body, model_id).await
231
232
233
234
    }

    async fn route_chat(
        &self,
235
        headers: Option<&HeaderMap>,
236
        body: &ChatCompletionRequest,
237
        model_id: Option<&str>,
238
    ) -> Response {
239
        self.route_chat_impl(headers, body, model_id).await
240
241
242
243
244
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
245
        _body: &CompletionRequest,
246
        _model_id: Option<&str>,
247
248
249
250
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

251
252
    async fn route_responses(
        &self,
253
254
255
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
256
    ) -> Response {
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
        // Use responses module for ALL requests (streaming and non-streaming)
        // Responses module handles:
        // - Request validation (previous_response_id XOR conversation)
        // - Loading response chain / conversation history from storage
        // - Conversion: ResponsesRequest → ChatCompletionRequest
        // - Execution through chat pipeline stages
        // - Conversion: ChatCompletionResponse → ResponsesResponse
        // - Response persistence
        // - MCP tool loop wrapper (future)
        responses::route_responses(
            &self.pipeline,
            Arc::new(body.clone()),
            headers.cloned(),
            model_id.map(|s| s.to_string()),
            self.shared_components.clone(),
            self.response_storage.clone(),
            self.conversation_storage.clone(),
            self.conversation_item_storage.clone(),
275
            self.mcp_manager.clone(),
276
277
278
            self.background_tasks.clone(),
        )
        .await
279
280
    }

281
282
283
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
284
        response_id: &str,
285
        _params: &ResponsesGetParams,
286
    ) -> Response {
287
        responses::get_response_impl(&self.response_storage, response_id).await
288
289
    }

290
291
292
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
        responses::cancel_response_impl(&self.response_storage, &self.background_tasks, response_id)
            .await
293
294
    }

295
296
297
298
299
300
301
302
303
    async fn route_classify(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &ClassifyRequest,
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

304
305
306
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
307
        _body: &EmbeddingRequest,
308
309
        _model_id: Option<&str>,
    ) -> Response {
310
311
312
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

313
314
315
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
316
        _body: &RerankRequest,
317
        _model_id: Option<&str>,
318
    ) -> Response {
319
320
321
322
323
324
325
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}