// streaming_tests.rs: tests for streaming (SSE) responses served by mock workers.
mod common;

use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
4
use futures_util::StreamExt;
5
6
use reqwest::Client;
use serde_json::json;
7
8
9
use sglang_router_rs::config::{
    CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
10
11
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
12

13
14
/// Test context that manages mock workers
struct TestContext {
15
    workers: Vec<MockWorker>,
16
    router: Arc<dyn RouterTrait>,
17
18
}

19
impl TestContext {
20
    async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
21
        let mut config = RouterConfig {
22
23
24
25
26
            mode: RoutingMode::Regular {
                worker_urls: vec![],
            },
            policy: PolicyConfig::Random,
            host: "127.0.0.1".to_string(),
27
            port: 3004,
28
29
30
31
            max_payload_size: 256 * 1024 * 1024,
            request_timeout_secs: 600,
            worker_startup_timeout_secs: 1,
            worker_startup_check_interval_secs: 1,
32
33
            dp_aware: false,
            api_key: None,
34
35
36
37
            discovery: None,
            metrics: None,
            log_dir: None,
            log_level: None,
38
            request_id_headers: None,
39
            max_concurrent_requests: 64,
40
41
42
            queue_size: 0,
            queue_timeout_secs: 60,
            rate_limit_tokens_per_second: None,
43
            cors_allowed_origins: vec![],
44
            retry: RetryConfig::default(),
45
            circuit_breaker: CircuitBreakerConfig::default(),
46
47
            disable_retries: false,
            disable_circuit_breaker: false,
48
            health_check: sglang_router_rs::config::HealthCheckConfig::default(),
49
            enable_igw: false,
50
51
        };

52
53
        let mut workers = Vec::new();
        let mut worker_urls = Vec::new();
54

55
56
57
58
59
60
        for worker_config in worker_configs {
            let mut worker = MockWorker::new(worker_config);
            let url = worker.start().await.unwrap();
            worker_urls.push(url);
            workers.push(worker);
        }
61

62
63
        if !workers.is_empty() {
            tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
64
65
        }

66
        config.mode = RoutingMode::Regular { worker_urls };
67

68
        let app_context = common::create_test_context(config);
69
        let router = RouterFactory::create_router(&app_context).await.unwrap();
70
        let router = Arc::from(router);
71

72
73
74
75
76
        if !workers.is_empty() {
            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
        }

        Self { workers, router }
77
78
79
    }

    /// Consumes the context and stops every mock worker, sleeping 100 ms before and
    /// after the stop loop.
    async fn shutdown(mut self) {
        let settle = tokio::time::Duration::from_millis(100);

        // Give pending operations a moment to complete before stopping anything.
        tokio::time::sleep(settle).await;

        for worker in self.workers.iter_mut() {
            worker.stop().await;
        }

        // And a moment for cleanup to complete afterwards.
        tokio::time::sleep(settle).await;
    }
90

91
92
93
94
95
96
97
98
99
100
101
    async fn make_streaming_request(
        &self,
        endpoint: &str,
        body: serde_json::Value,
    ) -> Result<Vec<String>, String> {
        let client = Client::new();

        // Get any worker URL for testing
        let worker_urls = self.router.get_worker_urls();
        if worker_urls.is_empty() {
            return Err("No available workers".to_string());
102
103
        }

104
        let worker_url = &worker_urls[0];
105

106
        let response = client
107
            .post(format!("{}{}", worker_url, endpoint))
108
109
110
111
            .json(&body)
            .send()
            .await
            .map_err(|e| format!("Request failed: {}", e))?;
112

113
114
115
        if !response.status().is_success() {
            return Err(format!("Request failed with status: {}", response.status()));
        }
116

117
118
119
120
121
122
        // Check if it's a streaming response
        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
123

124
125
126
        if !content_type.contains("text/event-stream") {
            return Err("Response is not a stream".to_string());
        }
127

128
129
        let mut stream = response.bytes_stream();
        let mut events = Vec::new();
130

131
132
133
134
        while let Some(chunk) = stream.next().await {
            if let Ok(bytes) = chunk {
                let text = String::from_utf8_lossy(&bytes);
                for line in text.lines() {
135
136
                    if let Some(stripped) = line.strip_prefix("data: ") {
                        events.push(stripped.to_string());
137
138
139
140
                    }
                }
            }
        }
141

142
        Ok(events)
143
    }
144
}
145

146
147
148
#[cfg(test)]
mod streaming_tests {
    use super::*;
149

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
    /// Streams a `/generate` request from a single healthy mock worker and checks that
    /// at least two events arrive, the last being the `[DONE]` sentinel.
    #[tokio::test]
    async fn test_generate_streaming() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 20001,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 10,
            fail_rate: 0.0,
        }])
        .await;

        let payload = json!({
            "text": "Stream test",
            "stream": true,
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 10
            }
        });

        // `expect` keeps the helper's error text in the failure output.
        let events = ctx
            .make_streaming_request("/generate", payload)
            .await
            .expect("streaming /generate request should succeed");

        // Should have at least one data chunk and [DONE]
        assert!(
            events.len() >= 2,
            "expected at least one chunk plus [DONE], got {:?}",
            events
        );
        assert_eq!(events.last().map(String::as_str), Some("[DONE]"));

        ctx.shutdown().await;
    }
180

181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    #[tokio::test]
    async fn test_v1_chat_completions_streaming() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 20002,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 10,
            fail_rate: 0.0,
        }])
        .await;

        let payload = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Count to 3"}
            ],
            "stream": true,
            "max_tokens": 20
199
200
        });

201
202
        let result = ctx
            .make_streaming_request("/v1/chat/completions", payload)
203
            .await;
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
        assert!(result.is_ok());

        let events = result.unwrap();
        assert!(events.len() >= 2); // At least one chunk + [DONE]

        // Verify events are valid JSON (except [DONE])
        for event in &events {
            if event != "[DONE]" {
                let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
                assert!(parsed.is_ok(), "Invalid JSON in SSE event: {}", event);

                let json = parsed.unwrap();
                assert_eq!(
                    json.get("object").and_then(|v| v.as_str()),
                    Some("chat.completion.chunk")
                );
            }
        }
222

223
        ctx.shutdown().await;
224
225
    }

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    /// Streams a `/v1/completions` request from a healthy mock worker and checks that
    /// at least two events are received.
    #[tokio::test]
    async fn test_v1_completions_streaming() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 20003,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 10,
            fail_rate: 0.0,
        }])
        .await;

        let payload = json!({
            "model": "test-model",
            "prompt": "Once upon a time",
            "stream": true,
            "max_tokens": 15
        });

        // `expect` keeps the helper's error text in the failure output.
        let events = ctx
            .make_streaming_request("/v1/completions", payload)
            .await
            .expect("streaming /v1/completions request should succeed");

        // At least one chunk + [DONE]
        assert!(
            events.len() >= 2,
            "expected at least one chunk plus [DONE], got {:?}",
            events
        );

        ctx.shutdown().await;
    }
252

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
    /// A worker configured with `fail_rate: 1.0` must make the streaming helper return
    /// an error.
    #[tokio::test]
    async fn test_streaming_with_error() {
        let always_failing = MockWorkerConfig {
            port: 20004,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 1.0, // Always fail
        };
        let ctx = TestContext::new(vec![always_failing]).await;

        let outcome = ctx
            .make_streaming_request(
                "/generate",
                json!({
                    "text": "This should fail",
                    "stream": true
                }),
            )
            .await;

        // With fail_rate: 1.0, the request should fail
        assert!(outcome.is_err());

        ctx.shutdown().await;
    }
275

276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
    #[tokio::test]
    async fn test_streaming_timeouts() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 20005,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 100, // Slow response
            fail_rate: 0.0,
        }])
        .await;

        let payload = json!({
            "text": "Slow stream",
            "stream": true,
            "sampling_params": {
                "max_new_tokens": 5
292
293
294
            }
        });

295
296
297
        let start = std::time::Instant::now();
        let result = ctx.make_streaming_request("/generate", payload).await;
        let elapsed = start.elapsed();
298

299
300
        assert!(result.is_ok());
        let events = result.unwrap();
301

302
303
304
        // Should have received multiple chunks over time
        assert!(!events.is_empty());
        assert!(elapsed.as_millis() >= 100); // At least one delay
305

306
        ctx.shutdown().await;
307
308
    }

309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
    #[tokio::test]
    async fn test_batch_streaming() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 20006,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 10,
            fail_rate: 0.0,
        }])
        .await;

        // Batch request with streaming
        let payload = json!({
            "text": ["First", "Second", "Third"],
            "stream": true,
            "sampling_params": {
                "max_new_tokens": 5
            }
327
328
        });

329
330
        let result = ctx.make_streaming_request("/generate", payload).await;
        assert!(result.is_ok());
331

332
333
334
        let events = result.unwrap();
        // Should have multiple events for batch
        assert!(events.len() >= 4); // At least 3 responses + [DONE]
335

336
        ctx.shutdown().await;
337
338
    }

339
340
341
342
343
344
345
346
347
348
    /// Checks `data: ` line extraction from SSE bodies against canned byte input.
    ///
    /// Runs as a plain synchronous test: nothing here awaits, so no async runtime is
    /// required.
    #[test]
    fn test_sse_format_parsing() {
        // Returns the payload of every line that starts with `data: `. `strip_prefix`
        // tests and slices in one step, so there is no hard-coded prefix length.
        fn parse_sse_chunk(chunk: &[u8]) -> Vec<String> {
            let text = String::from_utf8_lossy(chunk);
            text.lines()
                .filter_map(|line| line.strip_prefix("data: "))
                .map(str::to_string)
                .collect()
        }

        let sse_data =
            b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\" world\"}\n\ndata: [DONE]\n\n";
        let events = parse_sse_chunk(sse_data);

        assert_eq!(events.len(), 3);
        assert_eq!(events[0], "{\"text\":\"Hello\"}");
        assert_eq!(events[1], "{\"text\":\" world\"}");
        assert_eq!(events[2], "[DONE]");

        // Test with mixed content: `event:` fields and `:` comments are skipped.
        let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
        let events = parse_sse_chunk(mixed);

        assert_eq!(events.len(), 2);
        assert_eq!(events[0], "{\"test\":true}");
        assert_eq!(events[1], "[DONE]");
    }
}