factory.rs 8.69 KB
Newer Older
1
2
//! Factory for creating router instances

3
4
5
6
use super::{
    http::{pd_router::PDRouter, router::Router},
    RouterTrait,
};
7
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
8
use crate::policies::PolicyFactory;
9
10
use crate::server::AppContext;
use std::sync::Arc;
11
12
13
14
15

/// Factory for creating router instances based on configuration.
///
/// This is a field-less unit type; it only serves as a namespace for the
/// router construction functions defined in its `impl` block.
pub struct RouterFactory;

impl RouterFactory {
16
    /// Create a router instance from application context
17
    pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
18
19
20
21
22
        // Check if IGW mode is enabled
        if ctx.router_config.enable_igw {
            return Self::create_igw_router(ctx).await;
        }

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
        // Check connection mode and route to appropriate implementation
        match ctx.router_config.connection_mode {
            ConnectionMode::Grpc => {
                // Route to gRPC implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_grpc_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
                }
48
            }
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
            ConnectionMode::Http => {
                // Route to HTTP implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
                            .await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
                }
73
            }
74
75
76
77
        }
    }

    /// Create a regular router with injected policy
78
    async fn create_regular_router(
79
80
        worker_urls: &[String],
        policy_config: &PolicyConfig,
81
        ctx: &Arc<AppContext>,
82
83
84
85
    ) -> Result<Box<dyn RouterTrait>, String> {
        // Create policy
        let policy = PolicyFactory::create_from_config(policy_config);

86
        // Create regular router with injected policy and client
87
88
89
        let router = Router::new(
            worker_urls.to_vec(),
            policy,
90
91
92
93
94
            ctx.client.clone(),
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
            ctx.router_config.dp_aware,
            ctx.router_config.api_key.clone(),
95
            ctx.router_config.retry.clone(),
96
            ctx.router_config.circuit_breaker.clone(),
97
            ctx.router_config.health_check.clone(),
98
99
        )
        .await?;
100
101
102
103
104

        Ok(Box::new(router))
    }

    /// Create a PD router with injected policy
105
    async fn create_pd_router(
106
107
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
108
109
110
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
111
        ctx: &Arc<AppContext>,
112
    ) -> Result<Box<dyn RouterTrait>, String> {
113
114
115
116
117
        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
118

119
        // Create PD router with separate policies and client
120
121
122
        let router = PDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
123
124
            prefill_policy,
            decode_policy,
125
            ctx.client.clone(),
126
            ctx.router_config.request_timeout_secs,
127
128
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
129
            ctx.router_config.retry.clone(),
130
            ctx.router_config.circuit_breaker.clone(),
131
            ctx.router_config.health_check.clone(),
132
133
        )
        .await?;
134
135
136

        Ok(Box::new(router))
    }
137

138
139
    /// Build a gRPC router over a flat list of workers.
    ///
    /// The tokenizer location is taken from `router_config.tokenizer_path`
    /// when present, otherwise from `router_config.model_path`.
    ///
    /// # Errors
    ///
    /// Returns an error when neither `tokenizer_path` nor `model_path` is
    /// configured, and propagates the error returned by `GrpcRouter::new`.
    pub async fn create_grpc_router(
        worker_urls: &[String],
        policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
    ) -> Result<Box<dyn RouterTrait>, String> {
        use super::grpc::router::GrpcRouter;

        let cfg = &ctx.router_config;
        let routing_policy = PolicyFactory::create_from_config(policy_config);

        // Priority: tokenizer_path > model_path.
        let tokenizer_path = match (&cfg.tokenizer_path, &cfg.model_path) {
            (Some(path), _) | (None, Some(path)) => path.clone(),
            (None, None) => {
                return Err(
                    "gRPC router requires either --tokenizer-path or --model-path to be specified"
                        .to_string(),
                );
            }
        };

        let grpc_router = GrpcRouter::new(
            worker_urls.to_vec(),
            routing_policy,
            cfg.worker_startup_timeout_secs,
            cfg.worker_startup_check_interval_secs,
            cfg.dp_aware,
            cfg.api_key.clone(),
            cfg.effective_retry_config(),
            cfg.effective_circuit_breaker_config(),
            cfg.health_check.clone(),
            tokenizer_path,
        )
        .await?;

        Ok(Box::new(grpc_router))
    }

179
    /// Create a gRPC PD router with tokenizer and worker configuration
180
    pub async fn create_grpc_pd_router(
181
182
183
184
185
186
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
187
    ) -> Result<Box<dyn RouterTrait>, String> {
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
        use super::grpc::pd_router::GrpcPDRouter;

        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));

        // Determine which tokenizer path to use
        // Priority: tokenizer_path > model_path
        let tokenizer_path = ctx
            .router_config
            .tokenizer_path
            .clone()
            .or_else(|| ctx.router_config.model_path.clone())
            .ok_or_else(|| {
                "gRPC PD router requires either --tokenizer-path or --model-path to be specified"
                    .to_string()
            })?;

        // Create gRPC PD router
        let router = GrpcPDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
            prefill_policy,
            decode_policy,
            ctx.router_config.worker_startup_timeout_secs,
            ctx.router_config.worker_startup_check_interval_secs,
            ctx.router_config.dp_aware,
            ctx.router_config.api_key.clone(),
            ctx.router_config.effective_retry_config(),
            ctx.router_config.effective_circuit_breaker_config(),
            ctx.router_config.health_check.clone(),
            tokenizer_path,
        )
        .await?;

        Ok(Box::new(router))
226
227
    }

228
229
230
231
232
    /// Placeholder constructor for the IGW router.
    ///
    /// # Errors
    ///
    /// Always returns an error: IGW mode has no implementation yet, and the
    /// context argument is ignored.
    async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
        Err(String::from("IGW mode is not yet implemented"))
    }
233
}