// PD (Prefill-Decode) gRPC Router Implementation

3
use crate::config::types::RetryConfig;
4
use crate::core::{
5
    BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
6
};
7
use crate::grpc_client::SglangSchedulerClient;
8
use crate::metrics::RouterMetrics;
9
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
10
use crate::reasoning_parser::ParserFactory;
11
use crate::routers::RouterTrait;
12
use crate::tokenizer::traits::Tokenizer;
13
use crate::tool_parser::ParserRegistry;
14
15
16
17
18
19
20
use async_trait::async_trait;
use axum::{
    body::Body,
    extract::Request,
    http::{HeaderMap, StatusCode},
    response::{IntoResponse, Response},
};
21
use std::collections::HashMap;
22
use std::sync::Arc;
23
24
use std::time::Duration;
use tracing::{info, warn};
25

26
27
28
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
29
30
31
32
    /// Centralized worker registry
    worker_registry: Arc<WorkerRegistry>,
    /// Centralized policy registry
    policy_registry: Arc<PolicyRegistry>,
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    /// Load balancing policy for prefill
    prefill_policy: Arc<dyn LoadBalancingPolicy>,
    /// Load balancing policy for decode
    decode_policy: Arc<dyn LoadBalancingPolicy>,
    /// Tokenizer for handling text encoding/decoding
    tokenizer: Arc<dyn Tokenizer>,
    /// Reasoning parser factory for structured reasoning outputs
    reasoning_parser_factory: ParserFactory,
    /// Tool parser registry for function/tool calls
    tool_parser_registry: &'static ParserRegistry,
    /// Configuration
    timeout_secs: u64,
    interval_secs: u64,
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
    circuit_breaker_config: CircuitBreakerConfig,
}
51
52

impl GrpcPDRouter {
53
54
55
56
57
58
    /// Create a new gRPC PD router
    pub async fn new(
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
        prefill_policy: Arc<dyn LoadBalancingPolicy>,
        decode_policy: Arc<dyn LoadBalancingPolicy>,
59
        ctx: &Arc<crate::server::AppContext>,
60
    ) -> Result<Self, String> {
61
62
63
64
        // Get registries from context
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

65
66
67
        // Update metrics
        RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
        // Extract necessary components from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
            .clone();
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
            .clone();
        let tool_parser_registry = ctx
            .tool_parser_registry
            .ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;

83
        // Convert config CircuitBreakerConfig to core CircuitBreakerConfig
84
        let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
        let core_cb_config = CircuitBreakerConfig {
            failure_threshold: circuit_breaker_config.failure_threshold,
            success_threshold: circuit_breaker_config.success_threshold,
            timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
            window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
        };

        // Create gRPC clients for prefill workers
        let mut prefill_grpc_clients = HashMap::new();
        for (url, _bootstrap_port) in &prefill_urls {
            match SglangSchedulerClient::connect(url).await {
                Ok(client) => {
                    prefill_grpc_clients.insert(url.clone(), client);
                    info!("Connected to gRPC prefill worker at {}", url);
                }
                Err(e) => {
                    warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
                    // Continue with other workers
                }
            }
        }

        // Create gRPC clients for decode workers
        let mut decode_grpc_clients = HashMap::new();
        for url in &decode_urls {
            match SglangSchedulerClient::connect(url).await {
                Ok(client) => {
                    decode_grpc_clients.insert(url.clone(), client);
                    info!("Connected to gRPC decode worker at {}", url);
                }
                Err(e) => {
                    warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
                    // Continue with other workers
                }
            }
        }

        if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
            return Err("Failed to connect to any gRPC workers".to_string());
        }

126
127
128
        // Create Prefill Worker trait objects with gRPC connection mode and register them
        for (url, bootstrap_port) in &prefill_urls {
            if let Some(client) = prefill_grpc_clients.remove(url) {
129
130
                let worker = BasicWorkerBuilder::new(url.clone())
                    .worker_type(WorkerType::Prefill {
131
                        bootstrap_port: *bootstrap_port,
132
133
                    })
                    .connection_mode(crate::core::ConnectionMode::Grpc {
134
                        port: *bootstrap_port,
135
136
137
138
139
140
141
142
143
                    })
                    .circuit_breaker_config(core_cb_config.clone())
                    .health_config(HealthConfig {
                        timeout_secs: ctx.router_config.health_check.timeout_secs,
                        check_interval_secs: ctx.router_config.health_check.check_interval_secs,
                        endpoint: ctx.router_config.health_check.endpoint.clone(),
                        failure_threshold: ctx.router_config.health_check.failure_threshold,
                        success_threshold: ctx.router_config.health_check.success_threshold,
                    })
144
                    .grpc_client(client)
145
                    .build();
146
147
148
149
150
151
152
153
154

                // Register worker in the centralized registry
                worker_registry.register(Arc::new(worker));
            }
        }

        // Create Decode Worker trait objects with gRPC connection mode and register them
        for url in &decode_urls {
            if let Some(client) = decode_grpc_clients.remove(url) {
155
156
157
158
159
160
161
162
163
164
165
                let worker = BasicWorkerBuilder::new(url.clone())
                    .worker_type(WorkerType::Decode)
                    .connection_mode(crate::core::ConnectionMode::Grpc { port: None })
                    .circuit_breaker_config(core_cb_config.clone())
                    .health_config(HealthConfig {
                        timeout_secs: ctx.router_config.health_check.timeout_secs,
                        check_interval_secs: ctx.router_config.health_check.check_interval_secs,
                        endpoint: ctx.router_config.health_check.endpoint.clone(),
                        failure_threshold: ctx.router_config.health_check.failure_threshold,
                        success_threshold: ctx.router_config.health_check.success_threshold,
                    })
166
                    .grpc_client(client)
167
                    .build();
168

169
170
171
172
173
174
175
176
177
178
179
180
181
182
                // Register worker in the centralized registry
                worker_registry.register(Arc::new(worker));
            }
        }

        // Initialize policies with workers if needed - filter for gRPC workers only
        let prefill_workers = worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false, // include unhealthy workers during initialization
        );
183
184
185
186
187
188
189
        if let Some(cache_aware) = prefill_policy
            .as_any()
            .downcast_ref::<crate::policies::CacheAwarePolicy>()
        {
            cache_aware.init_workers(&prefill_workers);
        }

190
191
192
193
194
195
        let decode_workers = worker_registry.get_workers_filtered(
            None, // any model
            Some(WorkerType::Decode),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false, // include unhealthy workers during initialization
        );
196
197
198
199
200
201
202
        if let Some(cache_aware) = decode_policy
            .as_any()
            .downcast_ref::<crate::policies::CacheAwarePolicy>()
        {
            cache_aware.init_workers(&decode_workers);
        }

203
        // No need for local health checkers - WorkerRegistry handles health checking
204
205

        Ok(GrpcPDRouter {
206
207
            worker_registry,
            policy_registry,
208
209
210
211
212
            prefill_policy,
            decode_policy,
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
213
214
215
216
217
            timeout_secs: ctx.router_config.worker_startup_timeout_secs,
            interval_secs: ctx.router_config.worker_startup_check_interval_secs,
            dp_aware: ctx.router_config.dp_aware,
            api_key: ctx.router_config.api_key.clone(),
            retry_config: ctx.router_config.effective_retry_config(),
218
219
220
221
222
223
224
            circuit_breaker_config: core_cb_config,
        })
    }
}

impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225
226
227
228
229
230
231
232
233
234
235
236
237
238
        let prefill_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Prefill {
                bootstrap_port: None,
            }),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false,
        );
        let decode_workers = self.worker_registry.get_workers_filtered(
            None,
            Some(WorkerType::Decode),
            Some(crate::core::ConnectionMode::Grpc { port: None }),
            false,
        );
239
        f.debug_struct("GrpcPDRouter")
240
241
            .field("prefill_workers_count", &prefill_workers.len())
            .field("decode_workers_count", &decode_workers.len())
242
243
244
245
            .field("timeout_secs", &self.timeout_secs)
            .field("interval_secs", &self.interval_secs)
            .field("dp_aware", &self.dp_aware)
            .finish()
246
247
248
249
250
251
252
253
254
255
    }
}

#[async_trait]
impl RouterTrait for GrpcPDRouter {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
256
257
258
259
260
261
        // TODO: Implement actual generation test for gRPC PD mode
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        )
            .into_response()
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
    }

    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_models(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::GenerateRequest,
280
        _model_id: Option<&str>,
281
282
283
284
285
286
287
288
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_chat(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ChatCompletionRequest,
289
        _model_id: Option<&str>,
290
291
292
293
294
295
296
297
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::CompletionRequest,
298
        _model_id: Option<&str>,
299
300
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
301
302
303
304
305
306
    }

    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ResponsesRequest,
307
        _model_id: Option<&str>,
308
309
    ) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
310
311
    }

312
313
314
315
316
317
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
        _params: &crate::protocols::spec::ResponsesGetParams,
    ) -> Response {
318
319
320
321
322
323
324
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

325
326
327
328
329
330
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::EmbeddingRequest,
        _model_id: Option<&str>,
    ) -> Response {
331
332
333
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

334
335
336
337
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::RerankRequest,
338
        _model_id: Option<&str>,
339
    ) -> Response {
340
341
342
343
344
345
346
        (StatusCode::NOT_IMPLEMENTED).into_response()
    }

    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}