main.rs 3.34 KB
Newer Older
1
// src/main.rs
2
use clap::Parser;
3
use clap::ValueEnum;
4

5
use sglang_router_rs::{router::PolicyConfig, server};
6
7
8
9
10

#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
    Random,
    RoundRobin,
11
    CacheAware,
12
13
}

14
15
16
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
17
18
19
    #[arg(
        long,
        default_value = "127.0.0.1",
20
        help = "Host address to bind the router server to. Default: 127.0.0.1"
21
    )]
22
23
    host: String,

24
25
26
27
28
    #[arg(
        long,
        default_value_t = 3001,
        help = "Port number to bind the router server to. Default: 3001"
    )]
29
30
    port: u16,

31
32
33
    #[arg(
        long,
        value_delimiter = ',',
34
        help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)"
35
    )]
36
37
    worker_urls: Vec<String>,

38
39
    #[arg(
        long,
40
        default_value_t = PolicyType::CacheAware,
41
        value_enum,
42
43
44
45
        help = "Load balancing policy to use for request distribution:\n\
              - random: Randomly select workers\n\
              - round_robin: Distribute requests in round-robin fashion\n\
              - cache_aware: Distribute requests in cache-aware fashion\n"
46
47
48
49
50
    )]
    policy: PolicyType,

    #[arg(
        long,
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
        default_value_t = 0.5,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5"
    )]
    cache_threshold: f32,

    #[arg(
        long,
        default_value_t = 1.0,
        requires = "policy",
        required_if_eq("policy", "cache_aware"),
        help = "Probability of using cache-aware routing (0.0-1.0). Default 1.0 for full cache-aware routing, suitable for perfectly divided prefix workloads. For uneven workloads, use a lower value to better distribute requests"
    )]
    cache_routing_prob: f32,

    #[arg(
        long,
        default_value_t = 60,
70
        requires = "policy",
71
72
        required_if_eq("policy", "cache_aware"),
        help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60"
73
    )]
74
    eviction_interval_secs: u64,
75
76
77

    #[arg(
        long,
78
        default_value_t = 2usize.pow(24),
79
        requires = "policy",
80
81
        required_if_eq("policy", "cache_aware"),
        help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
82
    )]
83
    max_tree_size: usize,
84
85
86
87
88
89
90
}

impl Args {
    /// Builds the router's `PolicyConfig` from the parsed command-line flags.
    ///
    /// Only `PolicyType::CacheAware` reads the cache-tuning fields
    /// (`cache_threshold`, `cache_routing_prob`, `eviction_interval_secs`,
    /// `max_tree_size`); the other policies map to parameterless configs.
    fn get_policy_config(&self) -> PolicyConfig {
        let Self {
            policy,
            cache_threshold,
            cache_routing_prob,
            eviction_interval_secs,
            max_tree_size,
            ..
        } = self;

        match policy {
            PolicyType::CacheAware => PolicyConfig::CacheAwareConfig {
                cache_threshold: *cache_threshold,
                cache_routing_prob: *cache_routing_prob,
                eviction_interval_secs: *eviction_interval_secs,
                max_tree_size: *max_tree_size,
            },
            PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
            PolicyType::Random => PolicyConfig::RandomConfig,
        }
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let args = Args::parse();
104
105
    let policy_config = args.get_policy_config();
    server::startup(args.host, args.port, args.worker_urls, policy_config).await
106
}