// src/main.rs
use clap::Parser;
3
use clap::ValueEnum;
4
5
// declare child modules
mod router;
6
7
mod server;
mod tree;
use crate::router::PolicyConfig;

// Load-balancing strategies selectable with `--policy`.
//
// clap's `ValueEnum` derive spells each variant in kebab-case on the command line.
#[derive(Clone, Debug, ValueEnum)]
pub enum PolicyType {
    // `--policy random`
    Random,
    // `--policy round-robin` (the default in `Args`)
    RoundRobin,
    // `--policy approx-tree`; also needs a tokenizer path and a cache threshold.
    ApproxTree,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
21
22
23
24
25
    #[arg(
        long,
        default_value = "127.0.0.1",
        help = "Host address to bind the server to"
    )]
26
27
    host: String,

28
    #[arg(long, default_value_t = 3001, help = "Port number to listen on")]
29
30
    port: u16,

31
32
33
34
35
    #[arg(
        long,
        value_delimiter = ',',
        help = "Comma-separated list of worker URLs to distribute requests to"
    )]
36
37
    worker_urls: Vec<String>,

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    #[arg(
        long,
        default_value_t = PolicyType::RoundRobin,
        value_enum,
        help = "Load balancing policy to use: random, round_robin, or approx_tree"
    )]
    policy: PolicyType,

    #[arg(
        long,
        requires = "policy",
        required_if_eq("policy", "approx_tree"),
        help = "Path to the tokenizer file, required when using approx_tree policy"
    )]
    tokenizer_path: Option<String>,

    #[arg(
        long,
        default_value = "0.50",
        requires = "policy",
        required_if_eq("policy", "approx_tree"),
        help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to cached worker if match rate exceeds threshold, otherwise routes to shortest queue worker"
    )]
    cache_threshold: Option<f32>,
}

impl Args {
    fn get_policy_config(&self) -> PolicyConfig {
        match self.policy {
            PolicyType::Random => PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
            PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
                tokenizer_path: self
                    .tokenizer_path
                    .clone()
                    .expect("tokenizer_path is required for approx_tree policy"),
                cache_threshold: self
                    .cache_threshold
                    .expect("cache_threshold is required for approx_tree policy"),
            },
        }
    }
80
81
82
83
84
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let args = Args::parse();
85
86
    let policy_config = args.get_policy_config();
    server::startup(args.host, args.port, args.worker_urls, policy_config).await
87
}