// gRPC Router Implementation

3
use std::sync::Arc;
4
5
6
7
8

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
9
    http::{HeaderMap, StatusCode},
10
11
    response::{IntoResponse, Response},
};
12
use tracing::debug;
13

14
15
16
17
18
19
use super::{
    context::SharedComponents,
    harmony::{serve_harmony_responses, HarmonyDetector, HarmonyResponsesContext},
    pipeline::RequestPipeline,
    responses,
};
20
use crate::{
21
    app_context::AppContext,
22
23
24
    core::WorkerRegistry,
    protocols::{
        chat::ChatCompletionRequest,
25
        classify::ClassifyRequest,
26
27
28
29
30
31
32
33
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::RouterTrait,
};
34

35
/// gRPC router implementation for SGLang
36
#[derive(Clone)]
37
#[allow(dead_code)]
38
pub struct GrpcRouter {
39
    worker_registry: Arc<WorkerRegistry>,
40
    pipeline: RequestPipeline,
41
    harmony_pipeline: RequestPipeline,
42
    shared_components: Arc<SharedComponents>,
43
44
    // Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
    responses_context: responses::ResponsesContext,
45
46
    // Harmony responses context (uses harmony pipeline)
    harmony_responses_context: responses::ResponsesContext,
47
}
48
49

impl GrpcRouter {
50
    /// Create a new gRPC router
51
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
52
53
54
55
56
57
58
59
60
61
62
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
63
64
65
66
67
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
68

69
        let worker_registry = ctx.worker_registry.clone();
70
        let _policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
71

72
        // Create shared components for pipeline
73
        let shared_components = Arc::new(SharedComponents {
74
75
76
77
78
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

79
        // Create regular pipeline
80
81
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
82
            _policy_registry.clone(),
83
84
85
86
87
88
89
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

90
91
        // Create Harmony pipelines
        let harmony_pipeline = RequestPipeline::new_harmony(
92
            worker_registry.clone(),
93
94
95
96
97
98
            _policy_registry.clone(),
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
99
100
        );

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
        // Extract shared dependencies for responses contexts
        let mcp_manager = ctx
            .mcp_manager
            .get()
            .ok_or_else(|| "gRPC router requires MCP manager".to_string())?
            .clone();

        // Helper closure to create responses context with a given pipeline
        let create_responses_context = |pipeline: &RequestPipeline| {
            responses::ResponsesContext::new(
                Arc::new(pipeline.clone()),
                shared_components.clone(),
                worker_registry.clone(),
                ctx.response_storage.clone(),
                ctx.conversation_storage.clone(),
                ctx.conversation_item_storage.clone(),
                mcp_manager.clone(),
            )
        };

        // Create responses contexts for both pipelines
        let responses_context = create_responses_context(&pipeline);
        let harmony_responses_context = create_responses_context(&harmony_pipeline);

125
        Ok(GrpcRouter {
126
            worker_registry,
127
            pipeline,
128
            harmony_pipeline,
129
            shared_components,
130
            responses_context,
131
            harmony_responses_context,
132
133
        })
    }
134
135
136
137

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
138
        headers: Option<&HeaderMap>,
139
140
141
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
142
143
144
        // Choose Harmony pipeline if model indicates Harmony
        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);

145
        debug!(
146
147
            "Processing chat completion request for model: {:?}, using_harmony={}",
            model_id, is_harmony
148
149
        );

150
151
152
153
154
155
156
157
        let pipeline = if is_harmony {
            &self.harmony_pipeline
        } else {
            &self.pipeline
        };

        // Use selected pipeline for ALL requests (streaming and non-streaming)
        pipeline
158
            .execute_chat(
159
                Arc::new(body.clone()),
160
161
162
163
164
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
165
166
    }

167
168
169
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
170
        headers: Option<&HeaderMap>,
171
172
173
174
175
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

176
177
178
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
179
                Arc::new(body.clone()),
180
181
182
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
183
            )
184
            .await
185
    }
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

    /// Main route_responses implementation (pipeline-based for Harmony)
    async fn route_responses_impl(
        &self,
        _headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing Harmony responses request for model: {:?}",
            model_id
        );

        // Create HarmonyResponsesContext from existing responses context
        let harmony_ctx = HarmonyResponsesContext::new(
            Arc::new(self.harmony_pipeline.clone()),
            self.shared_components.clone(),
            self.harmony_responses_context.mcp_manager.clone(),
            self.harmony_responses_context.response_storage.clone(),
        );

        // Use serve_harmony_responses for multi-turn MCP tool orchestration
        match serve_harmony_responses(&harmony_ctx, body.clone()).await {
            Ok(response) => axum::Json(response).into_response(),
            Err(error_response) => error_response,
        }
    }
213
214
215
216
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217
        let stats = self.worker_registry.stats();
218
        f.debug_struct("GrpcRouter")
219
            .field("workers_count", &stats.total_workers)
220
            .finish()
221
222
223
224
225
226
227
228
229
230
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
231
232
233
234
235
236
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
253
254
255
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
256
    ) -> Response {
257
        self.route_generate_impl(headers, body, model_id).await
258
259
260
261
    }

    async fn route_chat(
        &self,
262
        headers: Option<&HeaderMap>,
263
        body: &ChatCompletionRequest,
264
        model_id: Option<&str>,
265
    ) -> Response {
266
        self.route_chat_impl(headers, body, model_id).await
267
268
269
270
271
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
272
        _body: &CompletionRequest,
273
        _model_id: Option<&str>,
274
275
276
277
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

278
279
    async fn route_responses(
        &self,
280
281
282
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
283
    ) -> Response {
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
        // Choose implementation based on Harmony model detection
        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);

        debug!(
            "Processing responses request for model: {:?}, using_harmony={}",
            model_id, is_harmony
        );

        if is_harmony {
            // Use pipeline-based implementation for Harmony models
            self.route_responses_impl(headers, body, model_id).await
        } else {
            // Use legacy responses module for non-Harmony models
            responses::route_responses(
                &self.responses_context,
                Arc::new(body.clone()),
                headers.cloned(),
                model_id.map(|s| s.to_string()),
            )
            .await
        }
305
306
    }

307
308
309
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
310
        response_id: &str,
311
        _params: &ResponsesGetParams,
312
    ) -> Response {
313
        responses::get_response_impl(&self.responses_context, response_id).await
314
315
    }

316
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
317
        responses::cancel_response_impl(&self.responses_context, response_id).await
318
319
    }

320
    async fn route_embeddings(
321
322
        &self,
        _headers: Option<&HeaderMap>,
323
        _body: &EmbeddingRequest,
324
325
326
327
328
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

329
    async fn route_classify(
330
331
        &self,
        _headers: Option<&HeaderMap>,
332
        _body: &ClassifyRequest,
333
334
        _model_id: Option<&str>,
    ) -> Response {
335
336
337
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

338
339
340
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
341
        _body: &RerankRequest,
342
        _model_id: Option<&str>,
343
    ) -> Response {
344
345
346
347
348
349
350
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}