//! Factory for creating router instances from the application context.

use super::{
    http::{pd_router::PDRouter, router::Router},
    RouterTrait,
};
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory;
use crate::server::AppContext;
use std::sync::Arc;

/// Factory for creating router instances based on configuration.
///
/// This is a unit struct: it carries no state and only serves as the
/// namespace for the associated constructor functions, each of which builds a
/// boxed [`RouterTrait`] object from an `AppContext`.
pub struct RouterFactory;

impl RouterFactory {
16
    /// Create a router instance from application context
17
    pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
18
19
20
21
22
        // Check if IGW mode is enabled
        if ctx.router_config.enable_igw {
            return Self::create_igw_router(ctx).await;
        }

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
        // Check connection mode and route to appropriate implementation
        match ctx.router_config.connection_mode {
            ConnectionMode::Grpc => {
                // Route to gRPC implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_grpc_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
                }
48
            }
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
            ConnectionMode::Http => {
                // Route to HTTP implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
                            .await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
                }
73
            }
74
75
76
77
        }
    }

    /// Create a regular router with injected policy
78
    async fn create_regular_router(
79
80
        worker_urls: &[String],
        policy_config: &PolicyConfig,
81
        ctx: &Arc<AppContext>,
82
83
84
85
    ) -> Result<Box<dyn RouterTrait>, String> {
        // Create policy
        let policy = PolicyFactory::create_from_config(policy_config);

86
        // Create regular router with injected policy and client
87
88
89
        let router = Router::new(
            worker_urls.to_vec(),
            policy,
90
91
92
93
94
            ctx.client.clone(),
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
            ctx.router_config.dp_aware,
            ctx.router_config.api_key.clone(),
95
            ctx.router_config.retry.clone(),
96
            ctx.router_config.circuit_breaker.clone(),
97
            ctx.router_config.health_check.clone(),
98
99
        )
        .await?;
100
101
102
103
104

        Ok(Box::new(router))
    }

    /// Create a PD router with injected policy
105
    async fn create_pd_router(
106
107
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
108
109
110
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
111
        ctx: &Arc<AppContext>,
112
    ) -> Result<Box<dyn RouterTrait>, String> {
113
114
115
116
117
        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
118

119
        // Create PD router with separate policies and client
120
121
122
        let router = PDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
123
124
            prefill_policy,
            decode_policy,
125
            ctx.client.clone(),
126
            ctx.router_config.request_timeout_secs,
127
128
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
129
            ctx.router_config.retry.clone(),
130
            ctx.router_config.circuit_breaker.clone(),
131
            ctx.router_config.health_check.clone(),
132
133
        )
        .await?;
134
135
136

        Ok(Box::new(router))
    }
137

138
139
    /// Create a gRPC router with injected policy
    pub async fn create_grpc_router(
140
141
142
        worker_urls: &[String],
        policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
143
    ) -> Result<Box<dyn RouterTrait>, String> {
144
145
146
147
148
        use super::grpc::router::GrpcRouter;

        // Create policy
        let policy = PolicyFactory::create_from_config(policy_config);

149
150
151
152
        // Get tokenizer from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
153
            .ok_or_else(|| {
154
155
156
157
158
159
160
161
162
163
                "gRPC router requires tokenizer to be initialized in AppContext".to_string()
            })?
            .clone();

        // Get reasoning parser factory from context
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| {
                "gRPC router requires reasoning parser factory to be initialized in AppContext"
164
                    .to_string()
165
166
167
168
169
170
171
            })?
            .clone();

        // Get tool parser registry from context
        let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
            "gRPC router requires tool parser registry to be initialized in AppContext".to_string()
        })?;
172
173
174
175
176
177
178
179
180
181
182
183

        // Create gRPC router
        let router = GrpcRouter::new(
            worker_urls.to_vec(),
            policy,
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
            ctx.router_config.dp_aware,
            ctx.router_config.api_key.clone(),
            ctx.router_config.effective_retry_config(),
            ctx.router_config.effective_circuit_breaker_config(),
            ctx.router_config.health_check.clone(),
184
185
186
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
187
188
189
190
        )
        .await?;

        Ok(Box::new(router))
191
192
    }

193
    /// Create a gRPC PD router with tokenizer and worker configuration
194
    pub async fn create_grpc_pd_router(
195
196
197
198
199
200
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
201
    ) -> Result<Box<dyn RouterTrait>, String> {
202
203
204
205
206
207
208
209
        use super::grpc::pd_router::GrpcPDRouter;

        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));

210
211
212
213
        // Get tokenizer from context
        let tokenizer = ctx
            .tokenizer
            .as_ref()
214
            .ok_or_else(|| {
215
216
217
218
219
220
221
222
223
224
                "gRPC PD router requires tokenizer to be initialized in AppContext".to_string()
            })?
            .clone();

        // Get reasoning parser factory from context
        let reasoning_parser_factory = ctx
            .reasoning_parser_factory
            .as_ref()
            .ok_or_else(|| {
                "gRPC PD router requires reasoning parser factory to be initialized in AppContext"
225
                    .to_string()
226
227
228
229
230
231
232
233
            })?
            .clone();

        // Get tool parser registry from context
        let tool_parser_registry = ctx.tool_parser_registry.ok_or_else(|| {
            "gRPC PD router requires tool parser registry to be initialized in AppContext"
                .to_string()
        })?;
234
235
236
237
238
239
240
241
242
243
244
245
246
247

        // Create gRPC PD router
        let router = GrpcPDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
            prefill_policy,
            decode_policy,
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
            ctx.router_config.dp_aware,
            ctx.router_config.api_key.clone(),
            ctx.router_config.effective_retry_config(),
            ctx.router_config.effective_circuit_breaker_config(),
            ctx.router_config.health_check.clone(),
248
249
250
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
251
252
253
254
        )
        .await?;

        Ok(Box::new(router))
255
256
    }

257
258
259
260
261
    /// Create an IGW router (placeholder for future implementation)
    async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
        // For now, return an error indicating IGW is not yet implemented
        Err("IGW mode is not yet implemented".to_string())
    }
262
}