router.rs 10.8 KB
Newer Older
1
use std::sync::Arc;
2
3
4
5
6

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
7
    http::{HeaderMap, StatusCode},
8
9
    response::{IntoResponse, Response},
};
10
use tracing::debug;
11

12
use super::{
13
14
15
16
    common::responses::{
        handlers::{cancel_response_impl, get_response_impl},
        utils::validate_worker_availability,
    },
17
    context::SharedComponents,
18
19
20
21
    harmony::{
        serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
        HarmonyResponsesContext,
    },
22
    pipeline::RequestPipeline,
23
    regular::responses,
24
};
25
use crate::{
26
    app_context::AppContext,
27
28
29
    core::WorkerRegistry,
    protocols::{
        chat::ChatCompletionRequest,
30
        classify::ClassifyRequest,
31
32
33
34
35
36
37
38
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::RouterTrait,
};
39

40
/// gRPC router implementation for SGLang
41
#[derive(Clone)]
42
#[allow(dead_code)]
43
pub struct GrpcRouter {
44
    worker_registry: Arc<WorkerRegistry>,
45
    pipeline: RequestPipeline,
46
    harmony_pipeline: RequestPipeline,
47
    shared_components: Arc<SharedComponents>,
48
    responses_context: responses::ResponsesContext,
49
    harmony_responses_context: responses::ResponsesContext,
50
}
51
52

impl GrpcRouter {
53
    /// Create a new gRPC router
54
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
55
56
57
58
59
60
61
62
63
64
65
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
66
67
68
69
70
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
71

72
        let worker_registry = ctx.worker_registry.clone();
73
        let _policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
74

75
        // Create shared components for pipeline
76
        let shared_components = Arc::new(SharedComponents {
77
78
79
80
81
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

82
        // Create regular pipeline
83
84
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
85
            _policy_registry.clone(),
86
87
88
89
90
91
92
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

93
94
        // Create Harmony pipelines
        let harmony_pipeline = RequestPipeline::new_harmony(
95
            worker_registry.clone(),
96
97
98
99
100
101
            _policy_registry.clone(),
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
102
103
        );

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
        // Extract shared dependencies for responses contexts
        let mcp_manager = ctx
            .mcp_manager
            .get()
            .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
            .clone();

        // Helper closure to create responses context with a given pipeline
        let create_responses_context = |pipeline: &RequestPipeline| {
            responses::ResponsesContext::new(
                Arc::new(pipeline.clone()),
                shared_components.clone(),
                worker_registry.clone(),
                ctx.response_storage.clone(),
                ctx.conversation_storage.clone(),
                ctx.conversation_item_storage.clone(),
                mcp_manager.clone(),
            )
        };

        // Create responses contexts for both pipelines
        let responses_context = create_responses_context(&pipeline);
        let harmony_responses_context = create_responses_context(&harmony_pipeline);

128
        Ok(GrpcRouter {
129
            worker_registry,
130
            pipeline,
131
            harmony_pipeline,
132
            shared_components,
133
            responses_context,
134
            harmony_responses_context,
135
136
        })
    }
137
138
139
140

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
141
        headers: Option<&HeaderMap>,
142
143
144
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
145
146
147
        // Choose Harmony pipeline if model indicates Harmony
        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);

148
        debug!(
149
150
            "Processing chat completion request for model: {:?}, using_harmony={}",
            model_id, is_harmony
151
152
        );

153
154
155
156
157
158
159
        let pipeline = if is_harmony {
            &self.harmony_pipeline
        } else {
            &self.pipeline
        };

        pipeline
160
            .execute_chat(
161
                Arc::new(body.clone()),
162
163
164
165
166
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
167
168
    }

169
170
171
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
172
        headers: Option<&HeaderMap>,
173
174
175
176
177
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

178
179
        self.pipeline
            .execute_generate(
180
                Arc::new(body.clone()),
181
182
183
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
184
            )
185
            .await
186
    }
187

188
189
190
    /// Main route_responses implementation
    ///
    /// Routes to either Harmony or regular responses implementation based on model detection
191
192
    async fn route_responses_impl(
        &self,
193
        headers: Option<&HeaderMap>,
194
195
196
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
197
198
199
200
201
202
203
204
205
        // 0. Fast worker validation (fail-fast before expensive operations)
        let requested_model: Option<&str> = model_id.or(Some(body.model.as_str()));

        if let Some(error_response) = requested_model
            .and_then(|model| validate_worker_availability(&self.worker_registry, model))
        {
            return error_response;
        }

206
207
        // Choose implementation based on Harmony model detection
        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
208

209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
        if is_harmony {
            debug!(
                "Processing Harmony responses request for model: {:?}, streaming: {:?}",
                model_id, body.stream
            );
            let harmony_ctx = HarmonyResponsesContext::new(
                Arc::new(self.harmony_pipeline.clone()),
                self.shared_components.clone(),
                self.harmony_responses_context.mcp_manager.clone(),
                self.harmony_responses_context.response_storage.clone(),
            );

            if body.stream.unwrap_or(false) {
                serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
            } else {
                match serve_harmony_responses(&harmony_ctx, body.clone()).await {
                    Ok(response) => axum::Json(response).into_response(),
                    Err(error_response) => error_response,
                }
228
            }
229
230
231
232
233
234
235
236
        } else {
            responses::route_responses(
                &self.responses_context,
                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
            )
            .await
237
238
        }
    }
239
240
241
242
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243
        let stats = self.worker_registry.stats();
244
        f.debug_struct("GrpcRouter")
245
            .field("workers_count", &stats.total_workers)
246
            .finish()
247
248
249
250
251
252
253
254
255
256
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
257
258
259
260
261
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
278
279
280
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
281
    ) -> Response {
282
        self.route_generate_impl(headers, body, model_id).await
283
284
285
286
    }

    async fn route_chat(
        &self,
287
        headers: Option<&HeaderMap>,
288
        body: &ChatCompletionRequest,
289
        model_id: Option<&str>,
290
    ) -> Response {
291
        self.route_chat_impl(headers, body, model_id).await
292
293
294
295
296
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
297
        _body: &CompletionRequest,
298
        _model_id: Option<&str>,
299
300
301
302
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

303
304
    async fn route_responses(
        &self,
305
306
307
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
308
    ) -> Response {
309
        self.route_responses_impl(headers, body, model_id).await
310
311
    }

312
313
314
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
315
        response_id: &str,
316
        _params: &ResponsesGetParams,
317
    ) -> Response {
318
        get_response_impl(&self.responses_context, response_id).await
319
320
    }

321
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
322
        cancel_response_impl(&self.responses_context, response_id).await
323
324
    }

325
    async fn route_embeddings(
326
327
        &self,
        _headers: Option<&HeaderMap>,
328
        _body: &EmbeddingRequest,
329
330
331
332
333
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

334
    async fn route_classify(
335
336
        &self,
        _headers: Option<&HeaderMap>,
337
        _body: &ClassifyRequest,
338
339
        _model_id: Option<&str>,
    ) -> Response {
340
341
342
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

343
344
345
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
346
        _body: &RerankRequest,
347
        _model_id: Option<&str>,
348
    ) -> Response {
349
350
351
352
353
354
355
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}