use clap::Parser;
use clap::ValueEnum;

use sglang_router_rs::{router::PolicyConfig, server};

#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
    Random,
    RoundRobin,
10
    CacheAware,
11
12
}

13
14
15
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
16
17
18
    #[arg(
        long,
        default_value = "127.0.0.1",
19
        help = "Host address to bind the router server to. Default: 127.0.0.1"
20
    )]
21
22
    host: String,

23
24
25
26
27
    #[arg(
        long,
        default_value_t = 3001,
        help = "Port number to bind the router server to. Default: 3001"
    )]
28
29
    port: u16,

30
31
32
    #[arg(
        long,
        value_delimiter = ',',
33
        help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
34
    )]
35
36
    worker_urls: Vec<String>,

37
38
    #[arg(
        long,
39
        default_value_t = PolicyType::CacheAware,
40
        value_enum,
41
42
43
        help = "Load balancing policy to use for request distribution:\n\
              - random: Randomly select workers\n\
              - round_robin: Distribute requests in round-robin fashion\n\
44
              - cache_aware: Distribute requests based on cache state and load balance\n"
45
46
47
48
49
    )]
    policy: PolicyType,

    #[arg(
        long,
50
51
52
53
54
55
56
57
58
        default_value_t = 0.5,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
    )]
    cache_threshold: f32,

    #[arg(
        long,
59
        default_value_t = 32,
60
61
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
62
        help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32"
63
    )]
64
65
66
67
68
69
70
71
72
73
    balance_abs_threshold: usize,

    #[arg(
        long,
        default_value_t = 1.0001,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001"
    )]
    balance_rel_threshold: f32,
74
75
76
77

    #[arg(
        long,
        default_value_t = 60,
78
        requires = "policy",
79
80
        required_if_eq("policy", "cache_aware"),
        help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
81
    )]
82
    eviction_interval_secs: u64,
83
84
85

    #[arg(
        long,
86
        default_value_t = 2usize.pow(24),
87
        requires = "policy",
88
89
        required_if_eq("policy", "cache_aware"),
        help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
90
    )]
91
    max_tree_size: usize,
92
93
94
95
96
97
98
}

impl Args {
    /// Builds the router's [`PolicyConfig`] from the parsed arguments.
    ///
    /// `Random` and `RoundRobin` map to parameterless config variants. Only
    /// `CacheAware` forwards the tuning fields (`cache_threshold`, the two
    /// balance thresholds, `eviction_interval_secs`, `max_tree_size`); for the
    /// other policies those fields are ignored.
    fn get_policy_config(&self) -> PolicyConfig {
        match self.policy {
            PolicyType::Random => PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
            PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval_secs,
                max_tree_size: self.max_tree_size,
            },
        }
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let args = Args::parse();
113
114
    let policy_config = args.get_policy_config();
    server::startup(args.host, args.port, args.worker_urls, policy_config).await
115
}