"docs/ZH_CN/vscode:/vscode.git/clone" did not exist on "ba02a9b8827c5600642ec91aa3f2cadb05380d6d"
request_formats_test.rs 11.2 KB
Newer Older
1
2
3
4
5
mod common;

use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
6
7
8
use sglang_router_rs::config::{
    CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
9
10
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
11

12
13
/// Test context that manages mock workers
struct TestContext {
14
    workers: Vec<MockWorker>,
15
    router: Arc<dyn RouterTrait>,
16
17
}

18
impl TestContext {
19
    async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
20
        let mut config = RouterConfig {
21
22
23
24
25
            mode: RoutingMode::Regular {
                worker_urls: vec![],
            },
            policy: PolicyConfig::Random,
            host: "127.0.0.1".to_string(),
26
            port: 3003,
27
28
29
30
            max_payload_size: 256 * 1024 * 1024,
            request_timeout_secs: 600,
            worker_startup_timeout_secs: 1,
            worker_startup_check_interval_secs: 1,
31
32
            dp_aware: false,
            api_key: None,
33
34
35
36
            discovery: None,
            metrics: None,
            log_dir: None,
            log_level: None,
37
            request_id_headers: None,
38
39
            max_concurrent_requests: 64,
            cors_allowed_origins: vec![],
40
            retry: RetryConfig::default(),
41
            circuit_breaker: CircuitBreakerConfig::default(),
42
43
            disable_retries: false,
            disable_circuit_breaker: false,
44
            health_check: sglang_router_rs::config::HealthCheckConfig::default(),
45
46
        };

47
48
        let mut workers = Vec::new();
        let mut worker_urls = Vec::new();
49

50
51
52
53
54
55
        for worker_config in worker_configs {
            let mut worker = MockWorker::new(worker_config);
            let url = worker.start().await.unwrap();
            worker_urls.push(url);
            workers.push(worker);
        }
56

57
58
        if !workers.is_empty() {
            tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
59
60
        }

61
        config.mode = RoutingMode::Regular { worker_urls };
62

63
        let app_context = common::create_test_context(config);
64
        let router = RouterFactory::create_router(&app_context).await.unwrap();
65
        let router = Arc::from(router);
66

67
68
69
70
71
        if !workers.is_empty() {
            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
        }

        Self { workers, router }
72
73
74
    }

    async fn shutdown(mut self) {
75
76
77
        // Small delay to ensure any pending operations complete
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

78
79
80
        for worker in &mut self.workers {
            worker.stop().await;
        }
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

        // Another small delay to ensure cleanup completes
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    }

    async fn make_request(
        &self,
        endpoint: &str,
        body: serde_json::Value,
    ) -> Result<serde_json::Value, String> {
        let client = Client::new();

        // Get any worker URL for testing
        let worker_urls = self.router.get_worker_urls();
        if worker_urls.is_empty() {
            return Err("No available workers".to_string());
        }

        let worker_url = &worker_urls[0];

        let response = client
102
            .post(format!("{}{}", worker_url, endpoint))
103
104
105
106
107
108
109
110
111
112
113
114
115
            .json(&body)
            .send()
            .await
            .map_err(|e| format!("Request failed: {}", e))?;

        if !response.status().is_success() {
            return Err(format!("Request failed with status: {}", response.status()));
        }

        response
            .json::<serde_json::Value>()
            .await
            .map_err(|e| format!("Failed to parse response: {}", e))
116
117
118
119
    }
}

#[cfg(test)]
mod request_format_tests {
    use super::*;

    /// Sends `payload` to `endpoint` via `TestContext::make_request` and returns
    /// the parsed JSON response.
    ///
    /// On failure it panics with the endpoint and the error text, so a failing
    /// test explains why instead of printing a bare
    /// `assertion failed: result.is_ok()`.
    async fn request_ok(
        ctx: &TestContext,
        endpoint: &str,
        payload: serde_json::Value,
    ) -> serde_json::Value {
        match ctx.make_request(endpoint, payload).await {
            Ok(response) => response,
            Err(err) => panic!("request to {} failed: {}", endpoint, err),
        }
    }

    /// `/generate` accepts plain text, text with sampling params, and input_ids.
    #[tokio::test]
    async fn test_generate_request_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19001,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic text request
        let payload = json!({
            "text": "Hello, world!",
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        // Test 2: Request with sampling parameters
        let payload = json!({
            "text": "Tell me a story",
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 100,
                "top_p": 0.9
            },
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        // Test 3: Request with input_ids
        let payload = json!({
            "input_ids": [1, 2, 3, 4, 5],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 50
            },
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        ctx.shutdown().await;
    }

    /// `/v1/chat/completions` returns a `chat.completion` object with `id` and `choices`.
    #[tokio::test]
    async fn test_v1_chat_completions_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19002,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic chat completion
        let payload = json!({
            "model": "test-model",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ],
            "stream": false
        });
        let response = request_ok(&ctx, "/v1/chat/completions", payload).await;
        assert!(response.get("choices").is_some(), "missing `choices`: {}", response);
        assert!(response.get("id").is_some(), "missing `id`: {}", response);
        assert_eq!(
            response.get("object").and_then(|v| v.as_str()),
            Some("chat.completion")
        );

        // Test 2: Chat completion with parameters
        let payload = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Tell me a joke"}
            ],
            "temperature": 0.8,
            "max_tokens": 150,
            "top_p": 0.95,
            "stream": false
        });
        request_ok(&ctx, "/v1/chat/completions", payload).await;

        ctx.shutdown().await;
    }

    /// `/v1/completions` handles single prompts, array prompts, and logprobs.
    #[tokio::test]
    async fn test_v1_completions_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19003,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic completion
        let payload = json!({
            "model": "test-model",
            "prompt": "Once upon a time",
            "max_tokens": 50,
            "stream": false
        });
        let response = request_ok(&ctx, "/v1/completions", payload).await;
        assert!(response.get("choices").is_some(), "missing `choices`: {}", response);
        assert_eq!(
            response.get("object").and_then(|v| v.as_str()),
            Some("text_completion")
        );

        // Test 2: Completion with array prompt
        let payload = json!({
            "model": "test-model",
            "prompt": ["First prompt", "Second prompt"],
            "temperature": 0.5,
            "stream": false
        });
        request_ok(&ctx, "/v1/completions", payload).await;

        // Test 3: Completion with logprobs
        let payload = json!({
            "model": "test-model",
            "prompt": "The capital of France is",
            "max_tokens": 10,
            "logprobs": 5,
            "stream": false
        });
        request_ok(&ctx, "/v1/completions", payload).await;

        ctx.shutdown().await;
    }

    /// `/generate` accepts batched text and batched input_ids.
    #[tokio::test]
    async fn test_batch_requests() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19004,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test batch text generation
        let payload = json!({
            "text": ["First text", "Second text", "Third text"],
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 50
            },
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        // Test batch with input_ids
        let payload = json!({
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        ctx.shutdown().await;
    }

    /// `/generate` accepts return_logprob, json_schema, and ignore_eos parameters.
    #[tokio::test]
    async fn test_special_parameters() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19005,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test with return_logprob
        let payload = json!({
            "text": "Test",
            "return_logprob": true,
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        // Test with json_schema
        let payload = json!({
            "text": "Generate JSON",
            "sampling_params": {
                "temperature": 0.0,
                "json_schema": "$$ANY$$"
            },
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        // Test with ignore_eos
        let payload = json!({
            "text": "Continue forever",
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 100,
                "ignore_eos": true
            },
            "stream": false
        });
        request_ok(&ctx, "/generate", payload).await;

        ctx.shutdown().await;
    }

    /// An empty JSON body sent to `/generate` is accepted by the mock worker.
    #[tokio::test]
    async fn test_error_handling() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19006,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test with empty body - should still work with mock worker
        let payload = json!({});

        // Mock worker accepts empty body
        request_ok(&ctx, "/generate", payload).await;

        ctx.shutdown().await;
    }
}