// PD (Prefill-Decode) gRPC Router Implementation

use std::sync::Arc;

use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
};
use tracing::debug;

use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
    app_context::AppContext,
    core::{ConnectionMode, WorkerRegistry, WorkerType},
    protocols::{
        chat::ChatCompletionRequest,
        classify::ClassifyRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::RouterTrait,
};
29

30
/// gRPC PD (Prefill-Decode) router implementation for SGLang
31
#[derive(Clone)]
32
pub struct GrpcPDRouter {
33
    worker_registry: Arc<WorkerRegistry>,
34
35
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
36
}
37
38

impl GrpcPDRouter {
39
    /// Create a new gRPC PD router
40
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
41
42
43
44
        // Get registries from context
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

45
46
47
48
49
50
51
52
53
54
55
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
            .clone();
56
57
58
59
60
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
            .clone();
61

62
        // Create shared components for pipeline
63
        let shared_components = Arc::new(SharedComponents {
64
65
66
67
68
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

69
70
71
72
        // Create PD pipeline
        let pipeline = RequestPipeline::new_pd(
            worker_registry.clone(),
            policy_registry.clone(),
73
74
75
76
77
78
79
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

80
        Ok(GrpcPDRouter {
81
            worker_registry,
82
83
            pipeline,
            shared_components,
84
85
        })
    }
86
87
88
89

    /// Main route_generate implementation with PD dual dispatch
    async fn route_generate_impl(
        &self,
90
        headers: Option<&HeaderMap>,
91
92
93
94
95
96
97
98
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing generate request for model: {:?} (PD mode)",
            model_id
        );

99
100
101
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
102
                Arc::new(body.clone()),
103
104
105
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
106
107
108
109
110
111
112
            )
            .await
    }

    /// Main route_chat implementation with PD dual dispatch
    async fn route_chat_impl(
        &self,
113
        headers: Option<&HeaderMap>,
114
115
116
117
118
119
120
121
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?} (PD mode)",
            model_id
        );

122
123
124
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
125
                Arc::new(body.clone()),
126
127
128
129
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
130
131
            .await
    }
132
133
134
135
}

impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136
137
138
139
140
        let prefill_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
141
            Some(ConnectionMode::Grpc { port: None }),
142
143
144
145
146
            false,
        );
        let decode_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Decode),
147
            Some(ConnectionMode::Grpc { port: None }),
148
149
            false,
        );
150
        f.debug_struct("GrpcPDRouter")
151
152
            .field("prefill_workers_count", &prefill_workers.len())
            .field("decode_workers_count", &decode_workers.len())
153
            .finish()
154
155
156
157
158
159
160
161
162
163
    }
}

#[async_trait]
impl RouterTrait for GrpcPDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
164
165
166
167
168
169
        // TODO: Implement actual generation test for gRPC PD mode
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        )
            .into_response()
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
186
187
188
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
189
    ) -> Response {
190
        self.route_generate_impl(headers, body, model_id).await
191
192
193
194
    }

    async fn route_chat(
        &self,
195
196
197
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
198
    ) -> Response {
199
        self.route_chat_impl(headers, body, model_id).await
200
201
202
203
204
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
205
        _body: &CompletionRequest,
206
        _model_id: Option<&str>,
207
208
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
209
210
211
212
213
    }

    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
214
        _body: &ResponsesRequest,
215
        _model_id: Option<&str>,
216
217
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
218
219
    }

220
221
222
223
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
224
        _params: &ResponsesGetParams,
225
    ) -> Response {
226
227
228
229
230
231
232
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

233
    async fn route_embeddings(
234
235
        &self,
        _headers: Option<&HeaderMap>,
236
        _body: &EmbeddingRequest,
237
238
239
240
241
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

242
    async fn route_classify(
243
244
        &self,
        _headers: Option<&HeaderMap>,
245
        _body: &ClassifyRequest,
246
247
        _model_id: Option<&str>,
    ) -> Response {
248
249
250
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

251
252
253
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
254
        _body: &RerankRequest,
255
        _model_id: Option<&str>,
256
    ) -> Response {
257
258
259
260
261
262
263
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}