use clap::Parser;
use clap::ValueEnum;

use sglang_router_rs::{router::PolicyConfig, server, server::ServerConfig};

#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
    Random,
    RoundRobin,
10
    CacheAware,
11
12
}

13
14
15
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
16
17
18
    #[arg(
        long,
        default_value = "127.0.0.1",
19
        help = "Host address to bind the router server to. Default: 127.0.0.1"
20
    )]
21
22
    host: String,

23
24
25
26
27
    #[arg(
        long,
        default_value_t = 3001,
        help = "Port number to bind the router server to. Default: 3001"
    )]
28
29
    port: u16,

30
31
32
    #[arg(
        long,
        value_delimiter = ',',
33
        help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
34
    )]
35
36
    worker_urls: Vec<String>,

37
38
    #[arg(
        long,
39
        default_value_t = PolicyType::CacheAware,
40
        value_enum,
41
42
43
        help = "Load balancing policy to use for request distribution:\n\
              - random: Randomly select workers\n\
              - round_robin: Distribute requests in round-robin fashion\n\
44
              - cache_aware: Distribute requests based on cache state and load balance\n"
45
46
47
48
49
    )]
    policy: PolicyType,

    #[arg(
        long,
50
51
52
53
54
55
56
57
58
        default_value_t = 0.5,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
    )]
    cache_threshold: f32,

    #[arg(
        long,
59
        default_value_t = 32,
60
61
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
62
        help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32"
63
    )]
64
65
66
67
68
69
70
71
72
73
    balance_abs_threshold: usize,

    #[arg(
        long,
        default_value_t = 1.0001,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001"
    )]
    balance_rel_threshold: f32,
74
75
76
77

    #[arg(
        long,
        default_value_t = 60,
78
        requires = "policy",
79
80
        required_if_eq("policy", "cache_aware"),
        help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
81
    )]
82
    eviction_interval_secs: u64,
83
84
85

    #[arg(
        long,
86
        default_value_t = 2usize.pow(24),
87
        requires = "policy",
88
89
        required_if_eq("policy", "cache_aware"),
        help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
90
    )]
91
    max_tree_size: usize,
92
93
94

    #[arg(long, default_value_t = false, help = "Enable verbose logging")]
    verbose: bool,
95
96
97
98
99
100
101
}

impl Args {
    /// Builds the router's [`PolicyConfig`] from the parsed command-line arguments.
    ///
    /// `Random` and `RoundRobin` carry no parameters. `CacheAware` copies the
    /// five cache/balance tuning values from `self`; for the other two
    /// policies those fields are ignored.
    fn get_policy_config(&self) -> PolicyConfig {
        match self.policy {
            PolicyType::Random => PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
            PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
            },
        }
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let args = Args::parse();
116
    let policy_config = args.get_policy_config();
117
118
119
120
121
122
123
124
    server::startup(ServerConfig {
        host: args.host,
        port: args.port,
        worker_urls: args.worker_urls,
        policy_config,
        verbose: args.verbose,
    })
    .await
125
}