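//! Backward-compatibility tests for `CacheAwarePolicy`: workers registered without a
//! `model_id` label (as older routers did) must still be added, selected, and removed
//! without errors, including when mixed with workers that do carry a `model_id`.
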
use std::{collections::HashMap, sync::Arc};

use sglang_router_rs::{
    core::{BasicWorkerBuilder, Worker, WorkerType},
    policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy},
};

/// Workers without a `model_id` label (simulating workers registered by
/// existing routers) and workers explicitly labelled `"unknown"` must both be
/// accepted by `CacheAwarePolicy`: add, select, and remove must all succeed.
#[test]
fn test_backward_compatibility_with_empty_model_id() {
    let config = CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0, // Disable background eviction for testing
        max_tree_size: 100,
    };

    let policy = CacheAwarePolicy::with_config(config);

    // worker1 has no model_id label - should default to "unknown".
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .api_key("test_api_key")
        .build();

    // worker2 carries an explicit "unknown" model_id, i.e. the same value the
    // missing label is expected to default to.
    let mut labels2 = HashMap::new();
    labels2.insert("model_id".to_string(), "unknown".to_string());
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .api_key("test_api_key")
        .labels(labels2)
        .build();

    // Add workers - should both go to "default" tree
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];

    // Selection must succeed and yield a valid index into `workers`.
    let selected = policy.select_worker(&workers, Some("test request"));
    assert!(
        matches!(selected, Some(idx) if idx < workers.len()),
        "Should select a worker"
    );

    // Remove workers - should work without errors
    policy.remove_worker(&worker1);
    policy.remove_worker(&worker2);
}

/// Workers with and without a `model_id` label can be registered with the same
/// `CacheAwarePolicy`. Selection must succeed whether the candidate list holds
/// only default-tree workers, only `llama-3` workers, or a mix of both.
#[test]
fn test_mixed_model_ids() {
    let config = CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0, // Disable background eviction for testing
        max_tree_size: 100,
    };

    let policy = CacheAwarePolicy::with_config(config);

    // worker1: no model_id label - defaults to "unknown" which goes to "default" tree.
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .api_key("test_api_key")
        .build();

    // worker2: explicit "llama-3" model_id.
    let mut labels2 = HashMap::new();
    labels2.insert("model_id".to_string(), "llama-3".to_string());
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels2)
        .api_key("test_api_key")
        .build();

    // worker3: explicit "unknown" model_id, expected to share worker1's tree.
    let mut labels3 = HashMap::new();
    labels3.insert("model_id".to_string(), "unknown".to_string());
    let worker3 = BasicWorkerBuilder::new("http://worker3:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels3)
        .build();

    // worker4: explicit "llama-3" model_id, expected to share worker2's tree.
    let mut labels4 = HashMap::new();
    labels4.insert("model_id".to_string(), "llama-3".to_string());
    let worker4 = BasicWorkerBuilder::new("http://worker4:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels4)
        .build();

    // Add all workers
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);
    policy.add_worker(&worker3);
    policy.add_worker(&worker4);

    // Candidates drawn only from the default ("unknown") group.
    let default_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    let selected = policy.select_worker(&default_workers, Some("test request"));
    assert!(
        matches!(selected, Some(idx) if idx < default_workers.len()),
        "Should select from default workers"
    );

    // Candidates drawn only from the "llama-3" group.
    let llama_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    let selected = policy.select_worker(&llama_workers, Some("test request"));
    assert!(
        matches!(selected, Some(idx) if idx < llama_workers.len()),
        "Should select from llama-3 workers"
    );

    // Candidates spanning both groups. This is the last use of each worker,
    // so they are moved rather than cloned.
    let all_workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1),
        Arc::new(worker2),
        Arc::new(worker3),
        Arc::new(worker4),
    ];
    let selected = policy.select_worker(&all_workers, Some("test request"));
    assert!(
        matches!(selected, Some(idx) if idx < all_workers.len()),
        "Should select from all workers"
    );
}

/// `remove_worker_by_url` (the URL-only removal method kept for backward
/// compatibility) must remove a worker even though the caller does not say
/// which model tree it lives in.
#[test]
fn test_remove_worker_by_url_backward_compat() {
    // Start from the defaults but disable background eviction, as the other
    // tests in this file do, so no eviction work runs during the test.
    let config = CacheAwareConfig {
        eviction_interval_secs: 0,
        ..CacheAwareConfig::default()
    };
    let policy = CacheAwarePolicy::with_config(config);

    // worker1: explicit "llama-3" model_id.
    let mut labels1 = HashMap::new();
    labels1.insert("model_id".to_string(), "llama-3".to_string());
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels1)
        .api_key("test_api_key")
        .build();

    // worker2: no model_id label - defaults to "unknown".
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .api_key("test_api_key")
        .build();

    // Add workers
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    // Remove by URL (backward compatibility method).
    // Should remove from all trees since we don't know the model.
    policy.remove_worker_by_url("http://worker1:8080");

    // NOTE(review): with a single-element candidate list, `Some(0)` is the only
    // possible successful result, so this assertion checks that selection still
    // works after removal; it does not by itself prove worker1 left every tree.
    let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2)];
    let selected = policy.select_worker(&workers, Some("test"));
    assert_eq!(selected, Some(0), "Should only have worker2 left");
}