//! Factory for creating router instances

3
use super::{
4
    http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
5
6
    RouterTrait,
};
7
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
8
use crate::policies::PolicyFactory;
9
10
use crate::server::AppContext;
use std::sync::Arc;
11
12
13
14
15

/// Factory for creating router instances based on configuration.
///
/// This is a unit struct with no state; it only serves as a namespace for the
/// associated constructor functions defined in its `impl` block.
pub struct RouterFactory;

impl RouterFactory {
16
    /// Create a router instance from application context
17
    pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
18
19
20
21
22
        // Check if IGW mode is enabled
        if ctx.router_config.enable_igw {
            return Self::create_igw_router(ctx).await;
        }

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
        // Check connection mode and route to appropriate implementation
        match ctx.router_config.connection_mode {
            ConnectionMode::Grpc => {
                // Route to gRPC implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_grpc_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
47
48
49
                    RoutingMode::OpenAI { .. } => {
                        Err("OpenAI mode requires HTTP connection_mode".to_string())
                    }
50
                }
51
            }
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
            ConnectionMode::Http => {
                // Route to HTTP implementation based on routing mode
                match &ctx.router_config.mode {
                    RoutingMode::Regular { worker_urls } => {
                        Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
                            .await
                    }
                    RoutingMode::PrefillDecode {
                        prefill_urls,
                        decode_urls,
                        prefill_policy,
                        decode_policy,
                    } => {
                        Self::create_pd_router(
                            prefill_urls,
                            decode_urls,
                            prefill_policy.as_ref(),
                            decode_policy.as_ref(),
                            &ctx.router_config.policy,
                            ctx,
                        )
                        .await
                    }
75
76
77
                    RoutingMode::OpenAI { worker_urls, .. } => {
                        Self::create_openai_router(worker_urls.clone(), ctx).await
                    }
78
                }
79
            }
80
81
82
83
        }
    }

    /// Create a regular router with injected policy
84
    async fn create_regular_router(
85
86
        worker_urls: &[String],
        policy_config: &PolicyConfig,
87
        ctx: &Arc<AppContext>,
88
89
90
91
    ) -> Result<Box<dyn RouterTrait>, String> {
        // Create policy
        let policy = PolicyFactory::create_from_config(policy_config);

92
93
        // Create regular router with injected policy and context
        let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
94
95
96
97
98

        Ok(Box::new(router))
    }

    /// Create a PD router with injected policy
99
    async fn create_pd_router(
100
101
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
102
103
104
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
105
        ctx: &Arc<AppContext>,
106
    ) -> Result<Box<dyn RouterTrait>, String> {
107
108
109
110
111
        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
112

113
        // Create PD router with separate policies and context
114
115
116
        let router = PDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
117
118
            prefill_policy,
            decode_policy,
119
            ctx,
120
121
        )
        .await?;
122
123
124

        Ok(Box::new(router))
    }
125

126
127
    /// Build a `GrpcRouter` for regular (non prefill/decode) routing.
    ///
    /// A policy is instantiated from `policy_config` and passed to
    /// `GrpcRouter::new` together with an owned copy of `worker_urls` and `ctx`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `GrpcRouter::new`.
    pub async fn create_grpc_router(
        worker_urls: &[String],
        policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
    ) -> Result<Box<dyn RouterTrait>, String> {
        use super::grpc::router::GrpcRouter;

        let routing_policy = PolicyFactory::create_from_config(policy_config);

        // The router takes ownership of its URL list, so hand it a copy.
        let urls = worker_urls.to_vec();
        let grpc_router = GrpcRouter::new(urls, routing_policy, ctx).await?;

        let boxed: Box<dyn RouterTrait> = Box::new(grpc_router);
        Ok(boxed)
    }

143
    /// Create a gRPC PD router with tokenizer and worker configuration
144
    pub async fn create_grpc_pd_router(
145
146
147
148
149
150
        prefill_urls: &[(String, Option<u16>)],
        decode_urls: &[String],
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
151
    ) -> Result<Box<dyn RouterTrait>, String> {
152
153
154
155
156
157
158
159
        use super::grpc::pd_router::GrpcPDRouter;

        // Create policies - use specific policies if provided, otherwise fall back to main policy
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));

160
        // Create gRPC PD router with context
161
162
163
164
165
        let router = GrpcPDRouter::new(
            prefill_urls.to_vec(),
            decode_urls.to_vec(),
            prefill_policy,
            decode_policy,
166
            ctx,
167
168
169
170
        )
        .await?;

        Ok(Box::new(router))
171
172
    }

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    /// Build an `OpenAIRouter` pointed at the first configured worker URL.
    ///
    /// Only the first entry of `worker_urls` is read; it becomes the base URL
    /// passed to `OpenAIRouter::new`, along with a clone of the circuit breaker
    /// settings from `ctx.router_config`.
    ///
    /// # Errors
    ///
    /// Returns an error when `worker_urls` is empty, and propagates the error
    /// returned by `OpenAIRouter::new`.
    async fn create_openai_router(
        worker_urls: Vec<String>,
        ctx: &Arc<AppContext>,
    ) -> Result<Box<dyn RouterTrait>, String> {
        // The first worker URL serves as the OpenAI-compatible base.
        let base_url = match worker_urls.first() {
            Some(url) => url.clone(),
            None => return Err("OpenAI mode requires at least one worker URL".to_string()),
        };

        let circuit_breaker = ctx.router_config.circuit_breaker.clone();
        let openai_router = OpenAIRouter::new(base_url, Some(circuit_breaker)).await?;

        let boxed: Box<dyn RouterTrait> = Box::new(openai_router);
        Ok(boxed)
    }

190
191
192
193
194
    /// Placeholder constructor for IGW mode.
    ///
    /// # Errors
    ///
    /// Always returns an error: IGW routing has no implementation here yet, and
    /// the context argument is unused.
    async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
        Err(String::from("IGW mode is not yet implemented"))
    }
195
}