// gRPC Router Implementation

use std::sync::Arc;

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
9
    http::{HeaderMap, StatusCode},
10
11
    response::{IntoResponse, Response},
};
12
use tracing::debug;
13

14
use crate::config::types::RetryConfig;
15
use crate::core::WorkerRegistry;
16
use crate::policies::PolicyRegistry;
17
use crate::protocols::spec::{
18
19
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
    ResponsesGetParams, ResponsesRequest,
20
};
21
use crate::reasoning_parser::ReasoningParserFactory;
22
use crate::routers::RouterTrait;
23
use crate::server::AppContext;
24
use crate::tokenizer::traits::Tokenizer;
25
use crate::tool_parser::ToolParserFactory;
26

27
/// gRPC router implementation for SGLang
28
#[derive(Clone)]
29
#[allow(dead_code)]
30
pub struct GrpcRouter {
31
32
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
33
    tokenizer: Arc<dyn Tokenizer>,
34
    reasoning_parser_factory: ReasoningParserFactory,
35
    tool_parser_factory: ToolParserFactory,
36
37
38
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
39
40
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
41
42
    pipeline: super::pipeline::ChatCompletionPipeline,
    shared_components: Arc<super::context::SharedComponents>,
43
}
44
45

impl GrpcRouter {
46
    /// Create a new gRPC router
47
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
48
49
50
51
52
53
54
55
56
57
58
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
59
60
61
62
63
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
64

65
66
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        // Create shared components for pipeline
        let shared_components = Arc::new(super::context::SharedComponents {
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

        // Create response processor
        let processor = super::processing::ResponseProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

        // Create streaming processor
        let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        ));

        // Create pipeline
        let pipeline = super::pipeline::ChatCompletionPipeline::new_regular(
            worker_registry.clone(),
            policy_registry.clone(),
            processor,
            streaming_processor,
        );

101
        Ok(GrpcRouter {
102
103
            worker_registry,
            policy_registry,
104
105
            tokenizer,
            reasoning_parser_factory,
106
            tool_parser_factory,
107
108
109
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
110
111
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
112
113
            pipeline,
            shared_components,
114
115
        })
    }
116
117
118
119

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
120
        headers: Option<&HeaderMap>,
121
122
123
124
125
126
127
128
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?}",
            model_id
        );

129
130
131
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
132
                Arc::new(body.clone()),
133
134
135
136
137
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
138
139
    }

140
141
142
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
143
        headers: Option<&HeaderMap>,
144
145
146
147
148
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

149
150
151
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
152
                Arc::new(body.clone()),
153
154
155
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
156
            )
157
            .await
158
    }
159
160
161
162
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163
        let stats = self.worker_registry.stats();
164
        f.debug_struct("GrpcRouter")
165
            .field("workers_count", &stats.total_workers)
166
167
            .field("dp_aware", &self.dp_aware)
            .finish()
168
169
170
171
172
173
174
175
176
177
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
178
179
180
181
182
183
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
200
201
202
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
203
    ) -> Response {
204
        self.route_generate_impl(headers, body, model_id).await
205
206
207
208
    }

    async fn route_chat(
        &self,
209
        headers: Option<&HeaderMap>,
210
        body: &ChatCompletionRequest,
211
        model_id: Option<&str>,
212
    ) -> Response {
213
        self.route_chat_impl(headers, body, model_id).await
214
215
216
217
218
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
219
        _body: &CompletionRequest,
220
        _model_id: Option<&str>,
221
222
223
224
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

225
226
227
    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
228
        _body: &ResponsesRequest,
229
        _model_id: Option<&str>,
230
231
232
233
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

234
235
236
237
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
238
        _params: &ResponsesGetParams,
239
    ) -> Response {
240
241
242
243
244
245
246
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

247
248
249
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
250
        _body: &EmbeddingRequest,
251
252
        _model_id: Option<&str>,
    ) -> Response {
253
254
255
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

256
257
258
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
259
        _body: &RerankRequest,
260
        _model_id: Option<&str>,
261
    ) -> Response {
262
263
264
265
266
267
268
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}