// gRPC router implementation for SGLang.

use std::sync::Arc;
4
5
6
7
8

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
9
    http::{HeaderMap, StatusCode},
10
11
    response::{IntoResponse, Response},
};
12
use tracing::debug;
13

14
use super::{context::SharedComponents, pipeline::RequestPipeline, responses};
15
use crate::{
16
    app_context::AppContext,
17
18
19
20
21
    config::types::RetryConfig,
    core::WorkerRegistry,
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
22
        classify::ClassifyRequest,
23
24
25
26
27
28
29
30
31
32
33
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    reasoning_parser::ParserFactory as ReasoningParserFactory,
    routers::RouterTrait,
    tokenizer::traits::Tokenizer,
    tool_parser::ParserFactory as ToolParserFactory,
};
34

35
/// gRPC router implementation for SGLang
36
#[derive(Clone)]
37
#[allow(dead_code)]
38
pub struct GrpcRouter {
39
40
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
41
    tokenizer: Arc<dyn Tokenizer>,
42
    reasoning_parser_factory: ReasoningParserFactory,
43
    tool_parser_factory: ToolParserFactory,
44
45
46
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
47
48
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
49
50
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
51
52
    // Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
    responses_context: responses::ResponsesContext,
53
}
54
55

impl GrpcRouter {
56
    /// Create a new gRPC router
57
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
58
59
60
61
62
63
64
65
66
67
68
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
69
70
71
72
73
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
74

75
76
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
77

78
        // Create shared components for pipeline
79
        let shared_components = Arc::new(SharedComponents {
80
81
82
83
84
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

85
86
87
88
        // Create pipeline
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
            policy_registry.clone(),
89
90
91
92
93
94
95
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

96
97
98
99
100
101
102
103
104
105
106
107
108
109
        // Create responses context with all dependencies
        let responses_context = responses::ResponsesContext::new(
            Arc::new(pipeline.clone()),
            shared_components.clone(),
            worker_registry.clone(),
            ctx.response_storage.clone(),
            ctx.conversation_storage.clone(),
            ctx.conversation_item_storage.clone(),
            ctx.mcp_manager
                .get()
                .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
                .clone(),
        );

110
        Ok(GrpcRouter {
111
112
            worker_registry,
            policy_registry,
113
114
            tokenizer,
            reasoning_parser_factory,
115
            tool_parser_factory,
116
117
118
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
119
120
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
121
122
            pipeline,
            shared_components,
123
            responses_context,
124
125
        })
    }
126
127
128
129

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
130
        headers: Option<&HeaderMap>,
131
132
133
134
135
136
137
138
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?}",
            model_id
        );

139
140
141
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
142
                Arc::new(body.clone()),
143
144
145
146
147
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
148
149
    }

150
151
152
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
153
        headers: Option<&HeaderMap>,
154
155
156
157
158
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

159
160
161
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
162
                Arc::new(body.clone()),
163
164
165
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
166
            )
167
            .await
168
    }
169
170
171
172
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173
        let stats = self.worker_registry.stats();
174
        f.debug_struct("GrpcRouter")
175
            .field("workers_count", &stats.total_workers)
176
177
            .field("dp_aware", &self.dp_aware)
            .finish()
178
179
180
181
182
183
184
185
186
187
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
188
189
190
191
192
193
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
210
211
212
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
213
    ) -> Response {
214
        self.route_generate_impl(headers, body, model_id).await
215
216
217
218
    }

    async fn route_chat(
        &self,
219
        headers: Option<&HeaderMap>,
220
        body: &ChatCompletionRequest,
221
        model_id: Option<&str>,
222
    ) -> Response {
223
        self.route_chat_impl(headers, body, model_id).await
224
225
226
227
228
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
229
        _body: &CompletionRequest,
230
        _model_id: Option<&str>,
231
232
233
234
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

235
236
    async fn route_responses(
        &self,
237
238
239
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
240
    ) -> Response {
241
        responses::route_responses(
242
            &self.responses_context,
243
244
245
246
247
            Arc::new(body.clone()),
            headers.cloned(),
            model_id.map(|s| s.to_string()),
        )
        .await
248
249
    }

250
251
252
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
253
        response_id: &str,
254
        _params: &ResponsesGetParams,
255
    ) -> Response {
256
        responses::get_response_impl(&self.responses_context, response_id).await
257
258
    }

259
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
260
        responses::cancel_response_impl(&self.responses_context, response_id).await
261
262
    }

263
264
265
266
267
268
269
270
271
    async fn route_classify(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &ClassifyRequest,
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

272
273
274
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
275
        _body: &EmbeddingRequest,
276
277
        _model_id: Option<&str>,
    ) -> Response {
278
279
280
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

281
282
283
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
284
        _body: &RerankRequest,
285
        _model_id: Option<&str>,
286
    ) -> Response {
287
288
289
290
291
292
293
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}