// cache_aware_backward_compat_test.rs
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType};
// Cache-aware policy types exercised by the tests in this file.
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
use std::collections::HashMap;
use std::sync::Arc;

/// Verifies that `CacheAwarePolicy` copes with workers that carry no real model id.
///
/// One worker is built without any labels and the other with an explicit
/// `model_id = "unknown"` label. Both are registered with the policy, one
/// selection is made over the pair, and both are removed again. The test passes
/// when a worker is selected and none of the calls panic.
#[test]
fn test_backward_compatibility_with_empty_model_id() {
    let config = CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0, // Disable background eviction for testing
        max_tree_size: 100,
    };

    let policy = CacheAwarePolicy::with_config(config);

    // Create workers with empty model_id (simulating existing routers).

    // worker1: no model_id label at all - expected to default to "unknown".
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .build();

    // worker2: model_id label explicitly set to "unknown".
    let mut labels2 = HashMap::new();
    labels2.insert("model_id".to_string(), "unknown".to_string());
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels2)
        .build();

    // Add workers - both are expected to land in the "default" tree.
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    // Create worker list. The originals are cloned because they are still
    // needed for the `remove_worker` calls below.
    let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];

    // Select worker - should work without errors.
    let selected = policy.select_worker(&workers, Some("test request"));
    assert!(selected.is_some(), "Should select a worker");

    // Remove workers - should work without errors.
    policy.remove_worker(&worker1);
    policy.remove_worker(&worker2);
}

/// Exercises `CacheAwarePolicy` with a mix of labelled and unlabelled workers.
///
/// Four workers are registered: `worker1` has no labels, `worker3` is labelled
/// `model_id = "unknown"`, and `worker2` / `worker4` are labelled
/// `model_id = "llama-3"`. Selection is then attempted over the
/// unlabelled/"unknown" pair, over the "llama-3" pair, and over all four
/// together; every call must yield a worker.
#[test]
fn test_mixed_model_ids() {
    let config = CacheAwareConfig {
        cache_threshold: 0.5,
        balance_abs_threshold: 2,
        balance_rel_threshold: 1.5,
        eviction_interval_secs: 0,
        max_tree_size: 100,
    };

    let policy = CacheAwarePolicy::with_config(config);

    // Create workers with different model_id scenarios.

    // worker1: no model_id label - expected to default to "unknown", which
    // goes to the "default" tree.
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .build();

    // worker2: explicit "llama-3" model.
    let mut labels2 = HashMap::new();
    labels2.insert("model_id".to_string(), "llama-3".to_string());
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels2)
        .build();

    // worker3: model_id label explicitly set to "unknown".
    let mut labels3 = HashMap::new();
    labels3.insert("model_id".to_string(), "unknown".to_string());
    let worker3 = BasicWorkerBuilder::new("http://worker3:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels3)
        .build();

    // worker4: second "llama-3" worker.
    let mut labels4 = HashMap::new();
    labels4.insert("model_id".to_string(), "llama-3".to_string());
    let worker4 = BasicWorkerBuilder::new("http://worker4:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels4)
        .build();

    // Add all workers.
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);
    policy.add_worker(&worker3);
    policy.add_worker(&worker4);

    // Test selection with default workers only.
    let default_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
    let selected = policy.select_worker(&default_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from default workers");

    // Test selection with specific model workers only.
    let llama_workers: Vec<Arc<dyn Worker>> =
        vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
    let selected = policy.select_worker(&llama_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from llama-3 workers");

    // Test selection with mixed workers.
    let all_workers: Vec<Arc<dyn Worker>> = vec![
        Arc::new(worker1.clone()),
        Arc::new(worker2.clone()),
        Arc::new(worker3.clone()),
        Arc::new(worker4.clone()),
    ];
    let selected = policy.select_worker(&all_workers, Some("test request"));
    assert!(selected.is_some(), "Should select from all workers");
}

/// Checks that selection still works after removing a worker via
/// `remove_worker_by_url`.
///
/// Two workers are registered (one labelled `model_id = "llama-3"`, one with no
/// labels). The labelled worker is then removed by URL only, and a selection
/// over a list containing just the remaining worker must return index `0`.
#[test]
fn test_remove_worker_by_url_backward_compat() {
    let config = CacheAwareConfig::default();
    let policy = CacheAwarePolicy::with_config(config);

    // Create workers with different model_ids.

    // worker1: explicit "llama-3" model.
    let mut labels1 = HashMap::new();
    labels1.insert("model_id".to_string(), "llama-3".to_string());
    let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
        .worker_type(WorkerType::Regular)
        .labels(labels1)
        .build();

    // worker2: no model_id label - expected to default to "unknown".
    let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
        .worker_type(WorkerType::Regular)
        .build();

    // Add workers.
    policy.add_worker(&worker1);
    policy.add_worker(&worker2);

    // Remove by URL (backward compatibility method).
    // Should remove from all trees since we don't know the model.
    policy.remove_worker_by_url("http://worker1:8080");

    // Verify removal worked: only worker2 is offered, so the selected index
    // must be 0.
    let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
    let selected = policy.select_worker(&workers, Some("test"));
    assert_eq!(selected, Some(0), "Should only have worker2 left");
}