//! gRPC PD (Prefill-Decode) router (`pd_router.rs`).
use std::sync::Arc;

3
4
5
6
use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
7
    http::{HeaderMap, StatusCode},
8
9
    response::{IntoResponse, Response},
};
10
use tracing::debug;
11

12
13
use super::{context::SharedComponents, pipeline::RequestPipeline};
use crate::{
14
    app_context::AppContext,
15
16
17
    core::{ConnectionMode, WorkerRegistry, WorkerType},
    protocols::{
        chat::ChatCompletionRequest,
18
        classify::ClassifyRequest,
19
20
21
22
23
24
25
26
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::RerankRequest,
        responses::{ResponsesGetParams, ResponsesRequest},
    },
    routers::RouterTrait,
};
/// gRPC PD (Prefill-Decode) router implementation for SGLang
29
#[derive(Clone)]
30
pub struct GrpcPDRouter {
31
    worker_registry: Arc<WorkerRegistry>,
32
33
    pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
34
}
35
36

impl GrpcPDRouter {
37
    /// Create a new gRPC PD router
38
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
39
40
41
42
        // Get registries from context
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

43
44
45
46
47
48
49
50
51
52
53
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
            .clone();
54
55
56
57
58
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
            .clone();
59

60
        // Create shared components for pipeline
61
        let shared_components = Arc::new(SharedComponents {
62
63
64
65
66
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

67
68
69
70
        // Create PD pipeline
        let pipeline = RequestPipeline::new_pd(
            worker_registry.clone(),
            policy_registry.clone(),
71
72
73
74
75
76
77
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            ctx.configured_tool_parser.clone(),
            ctx.configured_reasoning_parser.clone(),
        );

78
        Ok(GrpcPDRouter {
79
            worker_registry,
80
81
            pipeline,
            shared_components,
82
83
        })
    }
84
85
86
87

    /// Main route_generate implementation with PD dual dispatch
    async fn route_generate_impl(
        &self,
88
        headers: Option<&HeaderMap>,
89
90
91
92
93
94
95
96
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing generate request for model: {:?} (PD mode)",
            model_id
        );

97
98
99
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
100
                Arc::new(body.clone()),
101
102
103
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
104
105
106
107
108
109
110
            )
            .await
    }

    /// Main route_chat implementation with PD dual dispatch
    async fn route_chat_impl(
        &self,
111
        headers: Option<&HeaderMap>,
112
113
114
115
116
117
118
119
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?} (PD mode)",
            model_id
        );

120
121
122
        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_chat(
123
                Arc::new(body.clone()),
124
125
126
127
                headers.cloned(),
                model_id.map(|s| s.to_string()),
                self.shared_components.clone(),
            )
128
129
            .await
    }
130
131
132
133
}

impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134
135
136
137
138
        let prefill_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
139
            Some(ConnectionMode::Grpc { port: None }),
140
141
142
143
144
            false,
        );
        let decode_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Decode),
145
            Some(ConnectionMode::Grpc { port: None }),
146
147
            false,
        );
148
        f.debug_struct("GrpcPDRouter")
149
150
            .field("prefill_workers_count", &prefill_workers.len())
            .field("decode_workers_count", &decode_workers.len())
151
            .finish()
152
153
154
155
156
157
158
159
160
161
    }
}

#[async_trait]
impl RouterTrait for GrpcPDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
162
163
164
165
166
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        )
            .into_response()
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
183
184
185
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
186
    ) -> Response {
187
        self.route_generate_impl(headers, body, model_id).await
188
189
190
191
    }

    async fn route_chat(
        &self,
192
193
194
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
195
    ) -> Response {
196
        self.route_chat_impl(headers, body, model_id).await
197
198
199
200
201
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
202
        _body: &CompletionRequest,
203
        _model_id: Option<&str>,
204
205
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
206
207
208
209
210
    }

    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
211
        _body: &ResponsesRequest,
212
        _model_id: Option<&str>,
213
214
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
215
216
    }

217
218
219
220
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
221
        _params: &ResponsesGetParams,
222
    ) -> Response {
223
224
225
226
227
228
229
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

230
    async fn route_embeddings(
231
232
        &self,
        _headers: Option<&HeaderMap>,
233
        _body: &EmbeddingRequest,
234
235
236
237
238
        _model_id: Option<&str>,
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

239
    async fn route_classify(
240
241
        &self,
        _headers: Option<&HeaderMap>,
242
        _body: &ClassifyRequest,
243
244
        _model_id: Option<&str>,
    ) -> Response {
245
246
247
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

248
249
250
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
251
        _body: &RerankRequest,
252
        _model_id: Option<&str>,
253
    ) -> Response {
254
255
256
257
258
259
260
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}