// PD (Prefill-Decode) gRPC Router Implementation

use crate::config::types::RetryConfig;
4
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
5
use crate::policies::PolicyRegistry;
6
use crate::protocols::spec::{
7
8
    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
    ResponsesGetParams, ResponsesRequest,
9
};
10
use crate::reasoning_parser::ReasoningParserFactory;
11
use crate::routers::RouterTrait;
12
use crate::server::AppContext;
13
use crate::tokenizer::traits::Tokenizer;
14
use crate::tool_parser::ToolParserFactory;
15
16
17
18
use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
19
    http::{HeaderMap, StatusCode},
20
21
    response::{IntoResponse, Response},
};
22
use std::sync::Arc;
23
24

use tracing::debug;
25

26
/// gRPC PD (Prefill-Decode) router implementation for SGLang
27
#[derive(Clone)]
28
29
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
30
31
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
32
    tokenizer: Arc<dyn Tokenizer>,
33
    reasoning_parser_factory: ReasoningParserFactory,
34
    tool_parser_factory: ToolParserFactory,
35
36
37
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
38
39
    configured_reasoning_parser: Option<String>,
    configured_tool_parser: Option<String>,
40
41
    pipeline: super::pipeline::ChatCompletionPipeline,
    shared_components: Arc<super::context::SharedComponents>,
42
}
43
44

impl GrpcPDRouter {
45
    /// Create a new gRPC PD router
46
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
47
48
49
50
        // Get registries from context
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

51
52
53
54
55
56
57
58
59
60
61
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
            .clone();
62
63
64
65
66
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
            .clone();
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
        // Create shared components for pipeline
        let shared_components = Arc::new(super::context::SharedComponents {
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

        // Create response processor
        let processor = super::processing::ResponseProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

        // Create streaming processor
        let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        ));

        // Create PD pipeline
        let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
            worker_registry.clone(),
            policy_registry.clone(),
            processor,
            streaming_processor,
        );

101
        Ok(GrpcPDRouter {
102
103
            worker_registry,
            policy_registry,
104
105
            tokenizer,
            reasoning_parser_factory,
106
            tool_parser_factory,
107
108
109
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
110
111
            configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
            configured_tool_parser: ctx.configured_tool_parser.clone(),
112
113
            pipeline,
            shared_components,
114
115
        })
    }
116
117
118
119

    /// Main route_generate implementation with PD dual dispatch
    async fn route_generate_impl(
        &self,
120
        headers: Option<&HeaderMap>,
121
122
123
124
125
126
127
128
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing generate request for model: {:?} (PD mode)",
            model_id
        );

129
130
131
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
132
                Arc::new(body.clone()),
133
134
135
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
136
137
138
139
140
141
142
            )
            .await
    }

    /// Main route_chat implementation with PD dual dispatch
    async fn route_chat_impl(
        &self,
143
        headers: Option<&HeaderMap>,
144
145
146
147
148
149
150
151
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?} (PD mode)",
            model_id
        );

152
153
154
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
155
                Arc::new(body.clone()),
156
157
158
159
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
160
161
            .await
    }
162
163
164
165
}

impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166
167
168
169
170
        let prefill_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
171
            Some(ConnectionMode::Grpc { port: None }),
172
173
174
175
176
            false,
        );
        let decode_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Decode),
177
            Some(ConnectionMode::Grpc { port: None }),
178
179
            false,
        );
180
        f.debug_struct("GrpcPDRouter")
181
182
            .field("prefill_workers_count", &prefill_workers.len())
            .field("decode_workers_count", &decode_workers.len())
183
184
            .field("dp_aware", &self.dp_aware)
            .finish()
185
186
187
188
189
190
191
192
193
194
    }
}

#[async_trait]
impl RouterTrait for GrpcPDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
195
196
197
198
199
200
        // TODO: Implement actual generation test for gRPC PD mode
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        )
            .into_response()
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
217
218
219
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
220
    ) -> Response {
221
        self.route_generate_impl(headers, body, model_id).await
222
223
224
225
    }

    async fn route_chat(
        &self,
226
227
228
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
229
    ) -> Response {
230
        self.route_chat_impl(headers, body, model_id).await
231
232
233
234
235
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
236
        _body: &CompletionRequest,
237
        _model_id: Option<&str>,
238
239
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
240
241
242
243
244
    }

    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
245
        _body: &ResponsesRequest,
246
        _model_id: Option<&str>,
247
248
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
249
250
    }

251
252
253
254
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
255
        _params: &ResponsesGetParams,
256
    ) -> Response {
257
258
259
260
261
262
263
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

264
265
266
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
267
        _body: &EmbeddingRequest,
268
269
        _model_id: Option<&str>,
    ) -> Response {
270
271
272
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

273
274
275
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
276
        _body: &RerankRequest,
277
        _model_id: Option<&str>,
278
    ) -> Response {
279
280
281
282
283
284
285
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}