//! Integration tests for PolicyRegistry with RouterManager

use std::{collections::HashMap, sync::Arc};

use sglang_router_rs::{
    config::PolicyConfig, core::WorkerRegistry, policies::PolicyRegistry,
    protocols::worker_spec::WorkerConfigRequest, routers::router_manager::RouterManager,
};
#[tokio::test]
async fn test_policy_registry_with_router_manager() {
    // Create HTTP client
13
    let _client = reqwest::Client::new();
14
15
16
17
18
19

    // Create shared registries
    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(PolicyConfig::RoundRobin));

    // Create RouterManager with shared registries
20
    let _router_manager = RouterManager::new(worker_registry.clone());
21
22
23
24
25
26
27
28

    // Add first worker for llama-3 with cache_aware policy hint
    let mut labels1 = HashMap::new();
    labels1.insert("policy".to_string(), "cache_aware".to_string());

    let _worker1_config = WorkerConfigRequest {
        url: "http://worker1:8000".to_string(),
        model_id: Some("llama-3".to_string()),
29
        api_key: Some("test_api_key".to_string()),
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
        worker_type: None,
        priority: None,
        cost: None,
        labels: labels1,
        bootstrap_port: None,
        tokenizer_path: None,
        reasoning_parser: None,
        tool_parser: None,
        chat_template: None,
    };

    // This would normally connect to a real worker, but for testing we'll just verify the structure
    // In a real test, we'd need to mock the worker or use a test server

    let _llama_policy = policy_registry.get_policy("llama-3");
    // After first worker is added, llama-3 should have a policy

    // Add second worker for llama-3 with different policy hint (should be ignored)
    let mut labels2 = HashMap::new();
    labels2.insert("policy".to_string(), "random".to_string());

    let _worker2_config = WorkerConfigRequest {
        url: "http://worker2:8000".to_string(),
        model_id: Some("llama-3".to_string()),
54
        api_key: Some("test_api_key".to_string()),
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
        worker_type: None,
        priority: None,
        cost: None,
        labels: labels2,
        bootstrap_port: None,
        tokenizer_path: None,
        reasoning_parser: None,
        tool_parser: None,
        chat_template: None,
    };

    // The second worker should use the same policy as the first (cache_aware)

    // Add worker for different model (gpt-4) with random policy
    let mut labels3 = HashMap::new();
    labels3.insert("policy".to_string(), "random".to_string());

    let _worker3_config = WorkerConfigRequest {
        url: "http://worker3:8000".to_string(),
        model_id: Some("gpt-4".to_string()),
75
        api_key: Some("test_api_key".to_string()),
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        worker_type: None,
        priority: None,
        cost: None,
        labels: labels3,
        bootstrap_port: None,
        tokenizer_path: None,
        reasoning_parser: None,
        tool_parser: None,
        chat_template: None,
    };

    let _gpt_policy = policy_registry.get_policy("gpt-4");

    // When we remove both llama-3 workers, the policy should be cleaned up

    println!("PolicyRegistry integration test structure created");
    println!("Note: This test requires mocking or test servers to fully execute");
}

/// Verifies that a model's policy survives until its last worker is removed.
#[test]
fn test_policy_registry_cleanup() {
    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // The first worker's hint decides the model's policy.
    let policy1 = registry.on_worker_added("model-1", Some("cache_aware"));
    assert_eq!(policy1.name(), "cache_aware");

    // A later worker for the same model reuses the existing policy and ignores its own hint.
    let policy2 = registry.on_worker_added("model-1", Some("random"));
    assert_eq!(policy2.name(), "cache_aware");

    assert!(registry.get_policy("model-1").is_some());

    // Removing one of the two workers keeps the policy alive.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_some());

    // Removing the last worker drops the model's policy mapping.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_none());

    println!("✓ PolicyRegistry cleanup test passed");
}

/// Verifies that each model keeps its own policy and that models without a hint fall back to
/// the registry's default (`RoundRobin` here).
#[test]
fn test_policy_registry_multiple_models() {
    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // Add workers for different models with different policies
    let llama_policy = registry.on_worker_added("llama-3", Some("cache_aware"));
    let gpt_policy = registry.on_worker_added("gpt-4", Some("random"));
    // No hint: the model gets the default policy the registry was built with.
    let mistral_policy = registry.on_worker_added("mistral", None);

    assert_eq!(llama_policy.name(), "cache_aware");
    assert_eq!(gpt_policy.name(), "random");
    assert_eq!(mistral_policy.name(), "round_robin");

    for model in ["llama-3", "gpt-4", "mistral"] {
        assert!(registry.get_policy(model).is_some(), "missing policy for {model}");
    }

    // The model -> policy-name mapping reflects exactly the three registered models.
    let mappings = registry.get_all_mappings();
    assert_eq!(mappings.len(), 3);
    assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
    assert_eq!(mappings.get("gpt-4").unwrap(), "random");
    assert_eq!(mappings.get("mistral").unwrap(), "round_robin");

    println!("✓ PolicyRegistry multiple models test passed");
}