pd_router.rs 6.99 KB
Newer Older
1
2
// PD (Prefill-Decode) gRPC Router Implementation

3
use crate::config::types::RetryConfig;
4
use crate::core::{WorkerRegistry, WorkerType};
5
use crate::metrics::RouterMetrics;
6
use crate::policies::PolicyRegistry;
7
use crate::reasoning_parser::ReasoningParserFactory;
8
use crate::routers::RouterTrait;
9
use crate::tokenizer::traits::Tokenizer;
10
use crate::tool_parser::ToolParserFactory;
11
12
13
14
15
16
17
use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
};
18
use std::sync::Arc;
19
use tracing::info;
20

21
22
23
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
24
25
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
26
    tokenizer: Arc<dyn Tokenizer>,
27
    reasoning_parser_factory: ReasoningParserFactory,
28
    tool_parser_factory: ToolParserFactory,
29

30
31
32
33
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
}
34
35

impl GrpcPDRouter {
36
    /// Create a new gRPC PD router
37
    pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
38
39
40
41
        // Get registries from context
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

42
43
44
45
46
47
48
49
50
51
52
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
            .clone();
53
54
55
56
57
        let tool_parser_factory = ctx
            .tool_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
            .clone();
58

59
        // Get prefill and decode workers from registry - they should have been created by WorkerManager
60
61
62
63
64
65
66
67
        let prefill_workers = worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false, // include unhealthy workers during initialization
        );
68

69
70
71
72
73
74
        let decode_workers = worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Decode),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false, // include unhealthy workers during initialization
        );
75
76
77
78
79
80
81
82

        // Update metrics
        RouterMetrics::set_active_workers(prefill_workers.len() + decode_workers.len());
        info!(
            "gRPC PD router found {} prefill and {} decode workers in registry",
            prefill_workers.len(),
            decode_workers.len()
        );
83

84
        // No need for local health checkers - WorkerRegistry handles health checking
85
86

        Ok(GrpcPDRouter {
87
88
            worker_registry,
            policy_registry,
89
90
            tokenizer,
            reasoning_parser_factory,
91
            tool_parser_factory,
92
93
94
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
95
96
97
98
99
100
        })
    }
}

impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        let prefill_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false,
        );
        let decode_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Decode),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false,
        );
115
        f.debug_struct("GrpcPDRouter")
116
117
            .field("prefill_workers_count", &prefill_workers.len())
            .field("decode_workers_count", &decode_workers.len())
118
119
            .field("dp_aware", &self.dp_aware)
            .finish()
120
121
122
123
124
125
126
127
128
129
    }
}

#[async_trait]
impl RouterTrait for GrpcPDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
130
131
132
133
134
135
        // TODO: Implement actual generation test for gRPC PD mode
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        )
            .into_response()
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::GenerateRequest,
154
        _model_id: Option<&str>,
155
156
157
158
159
160
161
162
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_chat(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ChatCompletionRequest,
163
        _model_id: Option<&str>,
164
165
166
167
168
169
170
171
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::CompletionRequest,
172
        _model_id: Option<&str>,
173
174
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
175
176
177
178
179
180
    }

    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ResponsesRequest,
181
        _model_id: Option<&str>,
182
183
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
184
185
    }

186
187
188
189
190
191
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
        _params: &crate::protocols::spec::ResponsesGetParams,
    ) -> Response {
192
193
194
195
196
197
198
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

199
200
201
202
203
204
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::EmbeddingRequest,
        _model_id: Option<&str>,
    ) -> Response {
205
206
207
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

208
209
210
211
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::RerankRequest,
212
        _model_id: Option<&str>,
213
    ) -> Response {
214
215
216
217
218
219
220
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}