factory.rs 4.89 KB
Newer Older
1
2
//! Factory for creating router instances

3
4
use super::grpc::pd_router::GrpcPDRouter;
use super::grpc::router::GrpcRouter;
5
use super::{
6
7
    http::{pd_router::PDRouter, router::Router},
    openai::OpenAIRouter,
8
9
    RouterTrait,
};
10
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
11
use crate::policies::PolicyFactory;
12
13
use crate::server::AppContext;
use std::sync::Arc;
14
15
16
17
18

/// Factory for creating router instances based on configuration
pub struct RouterFactory;

impl RouterFactory {
19
    /// Create a router instance from application context
20
    pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
21
        match ctx.router_config.connection_mode {
22
23
24
25
26
27
28
29
30
31
32
33
34
35
            ConnectionMode::Grpc => match &ctx.router_config.mode {
                RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await,
                RoutingMode::PrefillDecode {
                    prefill_policy,
                    decode_policy,
                    ..
                } => {
                    Self::create_grpc_pd_router(
                        prefill_policy.as_ref(),
                        decode_policy.as_ref(),
                        &ctx.router_config.policy,
                        ctx,
                    )
                    .await
36
                }
37
38
                RoutingMode::OpenAI { .. } => {
                    Err("OpenAI mode requires HTTP connection_mode".to_string())
39
                }
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
            },
            ConnectionMode::Http => match &ctx.router_config.mode {
                RoutingMode::Regular { .. } => Self::create_regular_router(ctx).await,
                RoutingMode::PrefillDecode {
                    prefill_policy,
                    decode_policy,
                    ..
                } => {
                    Self::create_pd_router(
                        prefill_policy.as_ref(),
                        decode_policy.as_ref(),
                        &ctx.router_config.policy,
                        ctx,
                    )
                    .await
                }
                RoutingMode::OpenAI { worker_urls, .. } => {
                    Self::create_openai_router(worker_urls.clone(), ctx).await
                }
            },
60
61
62
        }
    }

63
64
    /// Create a regular router
    pub async fn create_regular_router(
65
        ctx: &Arc<AppContext>,
66
    ) -> Result<Box<dyn RouterTrait>, String> {
67
        let router = Router::new(ctx).await?;
68
69
70
71
72

        Ok(Box::new(router))
    }

    /// Create a PD router with injected policy
73
    pub async fn create_pd_router(
74
75
76
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
77
        ctx: &Arc<AppContext>,
78
    ) -> Result<Box<dyn RouterTrait>, String> {
79
80
81
82
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
83

84
85
86
        ctx.policy_registry.set_prefill_policy(prefill_policy);
        ctx.policy_registry.set_decode_policy(decode_policy);

87
        let router = PDRouter::new(ctx).await?;
88
89
90

        Ok(Box::new(router))
    }
91

92
    /// Create a gRPC router with injected policy
93
94
    pub async fn create_grpc_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
        let router = GrpcRouter::new(ctx).await?;
95
96

        Ok(Box::new(router))
97
98
    }

99
    /// Create a gRPC PD router with tokenizer and worker configuration
100
    pub async fn create_grpc_pd_router(
101
102
103
104
        prefill_policy_config: Option<&PolicyConfig>,
        decode_policy_config: Option<&PolicyConfig>,
        main_policy_config: &PolicyConfig,
        ctx: &Arc<AppContext>,
105
    ) -> Result<Box<dyn RouterTrait>, String> {
106
107
108
109
110
        let prefill_policy =
            PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
        let decode_policy =
            PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));

111
112
113
        ctx.policy_registry.set_prefill_policy(prefill_policy);
        ctx.policy_registry.set_decode_policy(decode_policy);
        let router = GrpcPDRouter::new(ctx).await?;
114
115

        Ok(Box::new(router))
116
117
    }

118
119
120
121
122
123
124
125
126
127
    /// Create an OpenAI router
    async fn create_openai_router(
        worker_urls: Vec<String>,
        ctx: &Arc<AppContext>,
    ) -> Result<Box<dyn RouterTrait>, String> {
        let base_url = worker_urls
            .first()
            .cloned()
            .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;

128
129
130
131
        let router = OpenAIRouter::new(
            base_url,
            Some(ctx.router_config.circuit_breaker.clone()),
            ctx.response_storage.clone(),
132
            ctx.conversation_storage.clone(),
133
            ctx.conversation_item_storage.clone(),
134
135
        )
        .await?;
136
137
138

        Ok(Box::new(router))
    }
139
}