// gRPC Router Implementation

use std::sync::Arc;

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
};
use tracing::debug;

use crate::config::types::RetryConfig;
use crate::core::WorkerRegistry;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserFactory as ToolParserFactory;

use super::context::SharedComponents;
use super::pipeline::RequestPipeline;

32
/// gRPC router implementation for SGLang
33
#[derive(Clone)]
34
#[allow(dead_code)]
35
pub struct GrpcRouter {
36
37
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
38
    tokenizer: Arc<dyn Tokenizer>,
39
    reasoning_parser_factory: ReasoningParserFactory,
40
    tool_parser_factory: ToolParserFactory,
41
42
43
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
44
45
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
46
47
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
48
}
49
50

impl GrpcRouter {
51
    /// Create a new gRPC router
52
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
53
54
55
56
57
58
59
60
61
62
63
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
            .clone();
64
65
66
67
68
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC router requires tool parser factory".to_string())?
            .clone();
69

70
71
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();
Chang Su's avatar
Chang Su committed
72

73
        // Create shared components for pipeline
74
        let shared_components = Arc::new(SharedComponents {
75
76
77
78
79
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

80
81
82
83
        // Create pipeline
        let pipeline = RequestPipeline::new_regular(
            worker_registry.clone(),
            policy_registry.clone(),
84
85
86
87
88
89
90
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

91
        Ok(GrpcRouter {
92
93
            worker_registry,
            policy_registry,
94
95
            tokenizer,
            reasoning_parser_factory,
96
            tool_parser_factory,
97
98
99
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
100
101
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
102
103
            pipeline,
            shared_components,
104
105
        })
    }
106
107
108
109

    /// Main route_chat implementation
    async fn route_chat_impl(
        &self,
110
        headers: Option<&HeaderMap>,
111
112
113
114
115
116
117
118
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?}",
            model_id
        );

119
120
121
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
122
                Arc::new(body.clone()),
123
124
125
126
127
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
            .await
128
129
    }

130
131
132
    /// Main route_generate implementation
    async fn route_generate_impl(
        &self,
133
        headers: Option<&HeaderMap>,
134
135
136
137
138
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

139
140
141
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
142
                Arc::new(body.clone()),
143
144
145
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
146
            )
147
            .await
148
    }
149
150
151
152
}

impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153
        let stats = self.worker_registry.stats();
154
        f.debug_struct("GrpcRouter")
155
            .field("workers_count", &stats.total_workers)
156
157
            .field("dp_aware", &self.dp_aware)
            .finish()
158
159
160
161
162
163
164
165
166
167
    }
}

#[async_trait]
impl RouterTrait for GrpcRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
168
169
170
171
172
173
        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
        )
            .into_response()
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
190
191
192
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
193
    ) -> Response {
194
        self.route_generate_impl(headers, body, model_id).await
195
196
197
198
    }

    async fn route_chat(
        &self,
199
        headers: Option<&HeaderMap>,
200
        body: &ChatCompletionRequest,
201
        model_id: Option<&str>,
202
    ) -> Response {
203
        self.route_chat_impl(headers, body, model_id).await
204
205
206
207
208
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
209
        _body: &CompletionRequest,
210
        _model_id: Option<&str>,
211
212
213
214
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

215
216
217
    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
218
        _body: &ResponsesRequest,
219
        _model_id: Option<&str>,
220
221
222
223
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

224
225
226
227
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
228
        _params: &ResponsesGetParams,
229
    ) -> Response {
230
231
232
233
234
235
236
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

237
238
239
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
240
        _body: &EmbeddingRequest,
241
242
        _model_id: Option<&str>,
    ) -> Response {
243
244
245
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

246
247
248
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
249
        _body: &RerankRequest,
250
        _model_id: Option<&str>,
251
    ) -> Response {
252
253
254
255
256
257
258
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc"
    }
}