http-service.rs 28.1 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
// SPDX-License-Identifier: Apache-2.0

use anyhow::Error;
use async_stream::stream;
Neelay Shah's avatar
Neelay Shah committed
6
use dynamo_llm::protocols::{
7
    Annotated,
Ryan Olson's avatar
Ryan Olson committed
8
9
    codec::SseLineCodec,
    convert_sse_stream,
10
    openai::{
11
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
12
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
13
14
    },
};
15
use dynamo_llm::{
16
17
18
19
20
    http::service::{
        Metrics,
        error::HttpError,
        metrics::{Endpoint, ErrorType, RequestType, Status},
        service_v2::HttpService,
21
22
23
    },
    model_card::ModelDeploymentCard,
};
24
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
Neelay Shah's avatar
Neelay Shah committed
25
use dynamo_runtime::{
26
    CancellationToken,
27
    pipeline::{
28
        AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
29
30
    },
};
31
use futures::StreamExt;
32
use prometheus::{Registry, proto::MetricType};
33
use reqwest::StatusCode;
Ryan Olson's avatar
Ryan Olson committed
34
35
36
use std::{io::Cursor, sync::Arc};
use tokio::time::timeout;
use tokio_util::codec::FramedRead;
37

38
39
40
41
#[path = "common/ports.rs"]
mod ports;
use ports::get_random_port;

42
43
/// Stateless chat-completions test engine. Its `AsyncEngine` implementation in
/// this file sleeps for `max_tokens` milliseconds and then streams ten
/// "choice {i}" chunks.
struct CounterEngine {}

Ryan Olson's avatar
Ryan Olson committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/// Test engine that simulates a slow generation so that cancellation can be
/// observed from the outside.
struct LongRunningEngine {
    /// Length of the simulated generation, in milliseconds.
    delay_ms: u64,
    /// Shared flag written by the response stream and read by the tests.
    cancelled: Arc<std::sync::atomic::AtomicBool>,
}

impl LongRunningEngine {
    /// Builds an engine that "generates" for `delay_ms` milliseconds. The
    /// cancellation flag starts out cleared.
    fn new(delay_ms: u64) -> Self {
        use std::sync::atomic::AtomicBool;

        let cancelled = Arc::new(AtomicBool::new(false));
        Self { delay_ms, cancelled }
    }

    /// Returns the current value of the cancellation flag.
    fn was_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;

        self.cancelled.load(Ordering::Acquire)
    }
}

63
64
65
#[async_trait]
impl
    AsyncEngine<
66
        SingleIn<NvCreateChatCompletionRequest>,
67
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
68
69
70
71
72
        Error,
    > for CounterEngine
{
    async fn generate(
        &self,
73
        request: SingleIn<NvCreateChatCompletionRequest>,
74
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
75
76
77
        let (request, context) = request.transfer(());
        let ctx = context.context();

Paul Hendricks's avatar
Paul Hendricks committed
78
        // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
Ryan Olson's avatar
Ryan Olson committed
79
        #[allow(deprecated)]
Paul Hendricks's avatar
Paul Hendricks committed
80
        let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
81

82
        // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
83
        let mut generator = request.response_generator(ctx.id().to_string());
84
85
86
87

        let stream = stream! {
            tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
            for i in 0..10 {
88
                let output = generator.create_choice(i, Some(format!("choice {i}")), None, None, None);
Paul Hendricks's avatar
Paul Hendricks committed
89
90

                yield Annotated::from_data(output);
91
92
93
94
95
96
97
            }
        };

        Ok(ResponseStream::new(Box::pin(stream), ctx))
    }
}

Ryan Olson's avatar
Ryan Olson committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Chat-completions engine used by the disconnect tests: it emits nothing until
// either its delay elapses or the request context is stopped.
#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for LongRunningEngine
{
    /// Returns a stream that waits for `delay_ms` or for the context to stop,
    /// records the outcome in the shared `cancelled` flag, and then yields a
    /// single sentinel annotation.
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        use std::sync::atomic::Ordering;
        use std::time::Duration;

        let (_request, context) = request.transfer(());
        let ctx = context.context();

        tracing::info!(
            "LongRunningEngine: Starting generation with {}ms delay",
            self.delay_ms
        );

        let flag = Arc::clone(&self.cancelled);
        let run_for = Duration::from_millis(self.delay_ms);
        let stop_ctx = ctx.clone();

        let stream = async_stream::stream! {
            // Mark as cancelled up front: if the stream is dropped while it is
            // waiting below, nothing after this line runs and the flag stays set.
            flag.store(true, Ordering::SeqCst);

            let ran_to_completion = tokio::select! {
                // The full delay elapsed without interruption.
                _ = tokio::time::sleep(run_for) => true,
                // The request context was stopped first.
                _ = stop_ctx.stopped() => false,
            };
            flag.store(!ran_to_completion, Ordering::SeqCst);

            yield Annotated::<NvCreateChatCompletionStreamResponse>::from_annotation("event.dynamo.test.sentinel", &"DONE".to_string()).expect("Failed to create annotated response");
        };

        Ok(ResponseStream::new(Box::pin(stream), ctx))
    }
}

146
147
148
149
150
/// Stateless test engine whose `AsyncEngine` implementations in this file
/// always return an `HttpError` (403 for chat completions, 401 for completions).
struct AlwaysFailEngine {}

#[async_trait]
impl
    AsyncEngine<
151
        SingleIn<NvCreateChatCompletionRequest>,
152
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
153
154
155
156
157
        Error,
    > for AlwaysFailEngine
{
    async fn generate(
        &self,
158
        _request: SingleIn<NvCreateChatCompletionRequest>,
159
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
160
161
162
163
164
165
166
167
        Err(HttpError {
            code: 403,
            message: "Always fail".to_string(),
        })?
    }
}

#[async_trait]
168
169
170
171
172
173
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for AlwaysFailEngine
174
175
176
{
    async fn generate(
        &self,
177
        _request: SingleIn<NvCreateCompletionRequest>,
178
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
179
180
181
182
183
184
185
186
        Err(HttpError {
            code: 401,
            message: "Always fail".to_string(),
        })?
    }
}

fn compare_counter(
187
    metrics: &Metrics,
188
189
190
191
    model: &str,
    endpoint: &Endpoint,
    request_type: &RequestType,
    status: &Status,
192
    error_type: &ErrorType,
193
194
195
    expected: u64,
) {
    assert_eq!(
196
        metrics.get_request_counter(model, endpoint, request_type, status, error_type),
197
        expected,
198
        "model: {}, endpoint: {:?}, request_type: {:?}, status: {:?}, error_type: {:?}",
199
200
201
        model,
        endpoint.as_str(),
        request_type.as_str(),
202
203
        status.as_str(),
        error_type.as_str()
204
205
206
207
208
209
210
    );
}

/// Maps an (endpoint, request type, status) triple to a slot in the eight-entry
/// expected-counter arrays used by these tests.
///
/// Layout: `endpoint * 4 + request_type * 2 + status`, with Completions = 0,
/// ChatCompletions = 1, Unary = 0, Stream = 1, Success = 0, Error = 1.
///
/// # Panics
/// Panics via `todo!()` for any endpoint other than Completions or
/// ChatCompletions.
fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Status) -> usize {
    let endpoint_slot: usize = match endpoint {
        Endpoint::Completions => 0,
        Endpoint::ChatCompletions => 1,
        Endpoint::Embeddings
        | Endpoint::Responses
        | Endpoint::AnthropicMessages
        | Endpoint::Tensor
        | Endpoint::Images
        | Endpoint::Videos
        | Endpoint::Audios => todo!(),
    };

    let stream_slot: usize = match request_type {
        RequestType::Unary => 0,
        RequestType::Stream => 1,
    };

    let error_slot: usize = match status {
        Status::Success => 0,
        Status::Error => 1,
    };

    endpoint_slot * 4 + stream_slot * 2 + error_slot
}

233
fn compare_counters(metrics: &Metrics, model: &str, expected: &[u64; 8]) {
234
235
236
237
    for endpoint in &[Endpoint::Completions, Endpoint::ChatCompletions] {
        for request_type in &[RequestType::Unary, RequestType::Stream] {
            for status in &[Status::Success, Status::Error] {
                let index = compute_index(endpoint, request_type, status);
238
239
240
241
                let error_type = match status {
                    Status::Success => &ErrorType::None,
                    Status::Error => &ErrorType::Validation, // Test engines return 4xx errors
                };
242
                compare_counter(
243
                    metrics,
244
245
246
247
                    model,
                    endpoint,
                    request_type,
                    status,
248
                    error_type,
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
                    expected[index],
                );
            }
        }
    }
}

/// Bumps the expected count for one (endpoint, request type, status) triple by
/// one, using the slot that `compute_index` assigns to it.
fn inc_counter(
    endpoint: Endpoint,
    request_type: RequestType,
    status: Status,
    expected: &mut [u64; 8],
) {
    expected[compute_index(&endpoint, &request_type, &status)] += 1;
}

Paul Hendricks's avatar
Paul Hendricks committed
266
#[allow(deprecated)]
267
268
#[tokio::test]
async fn test_http_service() {
269
    let port = get_random_port().await;
270
    let service = HttpService::builder()
271
        .port(port)
272
273
274
275
        .enable_chat_endpoints(true)
        .enable_cmpl_endpoints(true)
        .build()
        .unwrap();
276
277
    let state = service.state_clone();
    let manager = state.manager();
278
279
280
281
282

    let token = CancellationToken::new();
    let cancel_token = token.clone();
    let task = tokio::spawn(async move { service.run(token.clone()).await });

283
284
285
    // Wait for the service to be ready before proceeding
    wait_for_service_ready(port).await;

286
287
    let registry = Registry::new();

288
289
    // TODO: Shouldn't this test know the card before it registers a model?
    let card = ModelDeploymentCard::with_name_only("foo");
290
    let counter = Arc::new(CounterEngine {});
291
    let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
292
293
294
    assert!(result.is_ok());

    let failure = Arc::new(AlwaysFailEngine {});
295
296
    let card = ModelDeploymentCard::with_name_only("bar");
    let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
297
298
    assert!(result.is_ok());

299
    let result = manager.add_completions_model("bar", card.mdcsum(), failure);
300
301
    assert!(result.is_ok());

302
    let metrics = state.metrics_clone();
303
304
305
306
307
    metrics.register(&registry).unwrap();

    let mut foo_counters = [0u64; 8];
    let mut bar_counters = [0u64; 8];

308
309
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
310
311
312

    let client = reqwest::Client::new();

313
314
315
    let message = dynamo_protocols::types::ChatCompletionRequestMessage::User(
        dynamo_protocols::types::ChatCompletionRequestUserMessage {
            content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
Paul Hendricks's avatar
Paul Hendricks committed
316
317
318
319
320
321
                "hi".to_string(),
            ),
            name: None,
        },
    );

322
    let mut request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
323
        .model("foo")
Paul Hendricks's avatar
Paul Hendricks committed
324
        .messages(vec![message])
325
        .build()
Paul Hendricks's avatar
Paul Hendricks committed
326
327
328
329
330
331
332
        .expect("Failed to build request");

    // let mut request = ChatCompletionRequest::builder()
    //     .model("foo")
    //     .add_user_message("hi")
    //     .build()
    //     .unwrap();
333
334
335

    // ==== ChatCompletions / Stream / Success ====
    request.stream = Some(true);
Paul Hendricks's avatar
Paul Hendricks committed
336
337

    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
338
339
340
    request.max_tokens = Some(3000);

    let response = client
341
        .post(format!("http://localhost:{}/v1/chat/completions", port))
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
        .json(&request)
        .send()
        .await
        .unwrap();

    assert!(response.status().is_success(), "{:?}", response);

    tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
    assert_eq!(metrics.get_inflight_count("foo"), 1);

    // process byte stream
    let _ = response.bytes().await.unwrap();

    inc_counter(
        Endpoint::ChatCompletions,
        RequestType::Stream,
        Status::Success,
        &mut foo_counters,
    );
361
362
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
363
364
365
366
367

    // check registry and look or the request duration histogram
    let families = registry.gather();
    let histogram_metric_family = families
        .into_iter()
368
369
370
371
372
373
374
375
        .find(|m| {
            m.get_name()
                == format!(
                    "{}_{}",
                    name_prefix::FRONTEND,
                    frontend_service::REQUEST_DURATION_SECONDS
                )
        })
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
        .expect("Histogram metric not found");

    assert_eq!(
        histogram_metric_family.get_field_type(),
        MetricType::HISTOGRAM
    );

    let histogram_metric = histogram_metric_family.get_metric();

    assert_eq!(histogram_metric.len(), 1); // We have one metric with label model

    let metric = &histogram_metric[0];
    let histogram = metric.get_histogram();

    let buckets = histogram.get_bucket();

    let mut found = false;
393
394
395
396
397
398
    let mut expected_count = 0;
    for bucket_idx in 1..buckets.len() {
        if buckets[bucket_idx].get_upper_bound() >= 2.5
            && buckets[bucket_idx - 1].get_upper_bound() < 2.5
        {
            found = true;
399
            assert_eq!(
400
401
402
                buckets[bucket_idx].get_cumulative_count(),
                1,
                "Observation should be counted in the bucket containing 2.5"
403
            );
404
            expected_count = 1;
405
406
        } else {
            assert_eq!(
407
408
                buckets[bucket_idx].get_cumulative_count(),
                expected_count,
409
410
411
412
413
414
415
416
417
418
                "No observations should be in this bucket"
            );
        }
    }

    assert!(found, "The expected bucket was not found");
    // ==== ChatCompletions / Stream / Success ====

    // ==== ChatCompletions / Unary / Success ====
    request.stream = Some(false);
Paul Hendricks's avatar
Paul Hendricks committed
419
420

    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
421
422
423
    request.max_tokens = Some(0);

    let future = client
424
        .post(format!("http://localhost:{}/v1/chat/completions", port))
425
426
427
428
429
430
431
432
433
434
435
436
        .json(&request)
        .send();

    let response = future.await.unwrap();

    assert!(response.status().is_success(), "{:?}", response);
    inc_counter(
        Endpoint::ChatCompletions,
        RequestType::Unary,
        Status::Success,
        &mut foo_counters,
    );
437
438
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
439
440
441
442
    // ==== ChatCompletions / Unary / Success ====

    // ==== ChatCompletions / Stream / Error ====
    request.model = "bar".to_string();
Paul Hendricks's avatar
Paul Hendricks committed
443
444

    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
445
446
447
448
    request.max_tokens = Some(0);
    request.stream = Some(true);

    let response = client
449
        .post(format!("http://localhost:{}/v1/chat/completions", port))
450
451
452
453
454
455
456
457
458
459
460
461
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::FORBIDDEN);
    inc_counter(
        Endpoint::ChatCompletions,
        RequestType::Stream,
        Status::Error,
        &mut bar_counters,
    );
462
463
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
464
465
466
467
468
469
    // ==== ChatCompletions / Stream / Error ====

    // ==== ChatCompletions / Unary / Error ====
    request.stream = Some(false);

    let response = client
470
        .post(format!("http://localhost:{}/v1/chat/completions", port))
471
472
473
474
475
476
477
478
479
480
481
482
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::FORBIDDEN);
    inc_counter(
        Endpoint::ChatCompletions,
        RequestType::Unary,
        Status::Error,
        &mut bar_counters,
    );
483
484
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
485
486
487
    // ==== ChatCompletions / Unary / Error ====

    // ==== Completions / Unary / Error ====
488
    let mut request = dynamo_protocols::types::CreateCompletionRequestArgs::default()
489
490
491
492
493
494
        .model("bar")
        .prompt("hi")
        .build()
        .unwrap();

    let response = client
495
        .post(format!("http://localhost:{}/v1/completions", port))
496
497
498
499
500
501
502
503
504
505
506
507
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    inc_counter(
        Endpoint::Completions,
        RequestType::Unary,
        Status::Error,
        &mut bar_counters,
    );
508
509
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
510
511
512
513
514
515
    // ==== Completions / Unary / Error ====

    // ==== Completions / Stream / Error ====
    request.stream = Some(true);

    let response = client
516
        .post(format!("http://localhost:{}/v1/completions", port))
517
518
519
520
521
522
523
524
525
526
527
528
        .json(&request)
        .send()
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    inc_counter(
        Endpoint::Completions,
        RequestType::Stream,
        Status::Error,
        &mut bar_counters,
    );
529
530
    compare_counters(&metrics, "foo", &foo_counters);
    compare_counters(&metrics, "bar", &bar_counters);
531
532
533
534
535
536
537
    // ==== Completions / Stream / Error ====

    // =========== Test Invalid Request ===========
    // send a completion request to a chat endpoint
    request.stream = Some(false);

    let response = client
538
        .post(format!("http://localhost:{}/v1/chat/completions", port))
539
540
541
542
543
        .json(&request)
        .send()
        .await
        .unwrap();

544
    assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{:?}", response);
545
546
547

    // =========== Query /metrics endpoint ===========
    let response = client
548
        .get(format!("http://localhost:{}/metrics", port))
549
550
551
552
553
554
555
556
557
558
        .send()
        .await
        .unwrap();

    assert!(response.status().is_success(), "{:?}", response);
    println!("{}", response.text().await.unwrap());

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576

// === HTTP Client Tests ===

/// Wait for the HTTP service to be ready by checking its health endpoint
///
/// Polls `http://localhost:{port}/health` every 10 ms. Any HTTP response, of
/// any status, counts as "ready"; only transport errors trigger a retry.
///
/// A single `reqwest::Client` is reused for all attempts instead of building a
/// new one per poll.
///
/// # Panics
/// Panics if a request fails once five seconds have passed since the first
/// attempt.
async fn wait_for_service_ready(port: u16) {
    let health_url = format!("http://localhost:{}/health", port);
    let client = reqwest::Client::new();

    // Named `deadline` rather than `timeout` so the imported
    // `tokio::time::timeout` function is not shadowed.
    let deadline = tokio::time::Duration::from_secs(5);
    let poll_interval = tokio::time::Duration::from_millis(10);
    let start = tokio::time::Instant::now();

    loop {
        match client.get(&health_url).send().await {
            Ok(_) => break,
            Err(_) if start.elapsed() < deadline => {
                tokio::time::sleep(poll_interval).await;
            }
            Err(e) => panic!("Service failed to start within timeout: {}", e),
        }
    }
}

577
578
579
580
// NOTE: BYOT (Bring Your Own Type) client tests were removed during the
// upstream async-openai migration. They depended on the forked
// dynamo_protocols::config and http::client modules which no longer exist.
// TODO: Rewrite these tests using the upstream async-openai client.
Ryan Olson's avatar
Ryan Olson committed
581
582
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
583
    let port = get_random_port().await;
584
585
586
    let service = HttpService::builder()
        .enable_chat_endpoints(true)
        .enable_cmpl_endpoints(true)
587
        .port(port)
588
589
        .build()
        .unwrap();
Ryan Olson's avatar
Ryan Olson committed
590
591
592
593
594
595
596
597
598
599
    let state = service.state_clone();
    let manager = state.manager();

    let token = CancellationToken::new();
    let cancel_token = token.clone();

    // Start the service
    let task = tokio::spawn(async move { service.run(token).await });

    // Wait for service to be ready
600
    wait_for_service_ready(port).await;
Ryan Olson's avatar
Ryan Olson committed
601
602

    // Create a long-running engine (10 seconds)
603
    let card = ModelDeploymentCard::with_name_only("slow-model");
Ryan Olson's avatar
Ryan Olson committed
604
605
    let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
    manager
606
        .add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
Ryan Olson's avatar
Ryan Olson committed
607
608
609
610
        .unwrap();

    let client = reqwest::Client::new();

611
612
613
    let message = dynamo_protocols::types::ChatCompletionRequestMessage::User(
        dynamo_protocols::types::ChatCompletionRequestUserMessage {
            content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
Ryan Olson's avatar
Ryan Olson committed
614
615
616
617
618
619
                "This will take a long time".to_string(),
            ),
            name: None,
        },
    );

620
    let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
Ryan Olson's avatar
Ryan Olson committed
621
622
623
624
625
626
627
628
629
630
631
        .model("slow-model")
        .messages(vec![message])
        .stream(false) // Test unary response
        .build()
        .expect("Failed to build request");

    // Start the request and cancel it after 1 second
    let start_time = std::time::Instant::now();

    let request_future = async {
        client
632
            .post(format!("http://localhost:{}/v1/chat/completions", port))
Ryan Olson's avatar
Ryan Olson committed
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
            .json(&request)
            .send()
            .await
    };

    // Use timeout to simulate client disconnect after 1 second
    let result = timeout(std::time::Duration::from_millis(1000), request_future).await;

    let elapsed = start_time.elapsed();

    // The request should timeout (simulating client disconnect)
    assert!(result.is_err(), "Request should have timed out");

    // Give the service a moment to detect the disconnect and propagate cancellation
    tokio::time::sleep(std::time::Duration::from_millis(500)).await;

    // Verify the engine was cancelled
    assert!(
        long_running_engine.was_cancelled(),
        "Engine should have been cancelled due to client disconnect"
    );

    // Verify cancellation happened quickly (within 2 seconds, not the full 10 seconds)
    assert!(
        elapsed < std::time::Duration::from_secs(2),
        "Cancellation should have propagated quickly, took {:?}",
        elapsed
    );

    tracing::info!(
        "✅ Client disconnect test passed! Request cancelled in {:?}, engine detected cancellation",
        elapsed
    );

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}

#[tokio::test]
async fn test_client_disconnect_cancellation_streaming() {
    dynamo_runtime::logging::init();

675
    let port = get_random_port().await;
676
677
678
    let service = HttpService::builder()
        .enable_chat_endpoints(true)
        .enable_cmpl_endpoints(true)
679
        .port(port)
680
681
        .build()
        .unwrap();
Ryan Olson's avatar
Ryan Olson committed
682
683
684
685
686
687
688
689
690
691
    let state = service.state_clone();
    let manager = state.manager();

    let token = CancellationToken::new();
    let cancel_token = token.clone();

    // Start the service
    let task = tokio::spawn(async move { service.run(token).await });

    // Wait for service to be ready
692
    wait_for_service_ready(port).await;
Ryan Olson's avatar
Ryan Olson committed
693
694

    // Create a long-running engine (10 seconds)
695
    let card = ModelDeploymentCard::with_name_only("slow-stream-model");
Ryan Olson's avatar
Ryan Olson committed
696
697
    let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
    manager
698
699
700
701
702
        .add_chat_completions_model(
            "slow-stream-model",
            card.mdcsum(),
            long_running_engine.clone(),
        )
Ryan Olson's avatar
Ryan Olson committed
703
704
705
706
        .unwrap();

    let client = reqwest::Client::new();

707
708
709
    let message = dynamo_protocols::types::ChatCompletionRequestMessage::User(
        dynamo_protocols::types::ChatCompletionRequestUserMessage {
            content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
Ryan Olson's avatar
Ryan Olson committed
710
711
712
713
714
715
                "This will stream for a long time".to_string(),
            ),
            name: None,
        },
    );

716
    let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
Ryan Olson's avatar
Ryan Olson committed
717
718
719
720
721
722
723
724
725
726
727
        .model("slow-stream-model")
        .messages(vec![message])
        .stream(true) // Test streaming response
        .build()
        .expect("Failed to build request");

    // Start the request and cancel it after 1 second
    let start_time = std::time::Instant::now();

    let request_future = async {
        let response = client
728
            .post(format!("http://localhost:{}/v1/chat/completions", port))
Ryan Olson's avatar
Ryan Olson committed
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
            .json(&request)
            .send()
            .await
            .unwrap();

        // Start reading the stream, then drop it to simulate client disconnect
        let mut stream = response.bytes_stream();
        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        // Read one chunk then drop the stream (simulating client disconnect)
        let _ = StreamExt::next(&mut stream).await;
        // Stream gets dropped here when function exits
    };

    // Use timeout to simulate the streaming request timing out
    let _result = timeout(std::time::Duration::from_millis(1500), request_future).await;

    let elapsed = start_time.elapsed();

    // Give the service time to detect the disconnect
    tokio::time::sleep(std::time::Duration::from_millis(1000)).await;

    // Verify the engine was cancelled
    assert!(
        long_running_engine.was_cancelled(),
        "Engine should have been cancelled due to streaming client disconnect"
    );

    // Verify cancellation happened reasonably quickly
    assert!(
        elapsed < std::time::Duration::from_secs(3),
        "Stream cancellation should have propagated reasonably quickly, took {:?}",
        elapsed
    );

    tracing::info!(
        "✅ Streaming client disconnect test passed! Stream cancelled in {:?}, engine detected cancellation",
        elapsed
    );

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}

#[tokio::test]
async fn test_request_id_annotation() {
    // TODO(ryan): make better fixtures, this is too much to test sometime so simple
    dynamo_runtime::logging::init();

778
    let port = get_random_port().await;
779
780
781
    let service = HttpService::builder()
        .enable_chat_endpoints(true)
        .enable_cmpl_endpoints(true)
782
        .port(port)
783
784
        .build()
        .unwrap();
Ryan Olson's avatar
Ryan Olson committed
785
786
787
788
789
790
791
792
793
794
    let state = service.state_clone();
    let manager = state.manager();

    let token = CancellationToken::new();
    let cancel_token = token.clone();

    // Start the service
    let task = tokio::spawn(async move { service.run(token).await });

    // Wait for service to be ready
795
    wait_for_service_ready(port).await;
Ryan Olson's avatar
Ryan Olson committed
796
797

    // Add a counter engine for this test
798
    let card = ModelDeploymentCard::with_name_only("test-model");
Ryan Olson's avatar
Ryan Olson committed
799
800
    let counter_engine = Arc::new(CounterEngine {});
    manager
801
        .add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
Ryan Olson's avatar
Ryan Olson committed
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
        .unwrap();

    // Create reqwest client directly
    let client = reqwest::Client::new();

    // Generate a UUID for the request ID
    let request_uuid = uuid::Uuid::new_v4();

    // Create the request JSON directly
    let request_json = serde_json::json!({
        "model": "test-model",
        "messages": [
            {
                "role": "user",
                "content": "Test request with annotation"
            }
        ],
        "stream": true,
        "max_tokens": 50,
        "nvext": {
            "annotations": ["request_id"]
        }
    });

    // Make the streaming request with custom header
    let response = client
828
        .post(format!("http://localhost:{}/v1/chat/completions", port))
Ryan Olson's avatar
Ryan Olson committed
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
        .header("x-dynamo-request-id", request_uuid.to_string())
        .json(&request_json)
        .send()
        .await
        .expect("Request should succeed");

    assert!(
        response.status().is_success(),
        "Response should be successful"
    );

    // Collect the entire response body as bytes first
    let body_bytes = response
        .bytes()
        .await
        .expect("Failed to read response body");
    let body_text = String::from_utf8_lossy(&body_bytes);

    // Create a cursor from the text and use SseLineCodec to parse it
    let cursor = Cursor::new(body_text.to_string());
    let framed = FramedRead::new(cursor, SseLineCodec::new());
    let annotated_stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(framed);

    // Look for the annotation in the stream
    let mut found_request_id_annotation = false;
    let mut received_request_id = None;

    // Process the annotated stream and look for the request_id annotation
    let mut annotated_stream = std::pin::pin!(annotated_stream);
    while let Some(annotated_response) = annotated_stream.next().await {
        // Check if this is a request_id annotation
860
861
862
863
864
865
866
867
868
869
870
871
872
873
        if let Some(event) = &annotated_response.event
            && event == "request_id"
        {
            found_request_id_annotation = true;
            // Extract the request ID from the annotation
            if let Some(comments) = &annotated_response.comment
                && let Some(comment) = comments.first()
            {
                // The comment contains a JSON-encoded string, so we need to parse it
                if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
                    received_request_id = Some(parsed_value);
                } else {
                    // Fallback: remove quotes manually if JSON parsing fails
                    received_request_id = Some(comment.trim_matches('"').to_string());
Ryan Olson's avatar
Ryan Olson committed
874
875
                }
            }
876
            break;
Ryan Olson's avatar
Ryan Olson committed
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
        }
    }

    // Verify we found the annotation
    assert!(
        found_request_id_annotation,
        "Should have received request_id annotation in the stream"
    );

    // Verify the request ID matches what we sent
    assert!(
        received_request_id.is_some(),
        "Should have received the request ID in the annotation"
    );

    let received_uuid_str = received_request_id.unwrap();
    assert_eq!(
        received_uuid_str,
        request_uuid.to_string(),
        "Received request ID should match the one we sent: expected {}, got {}",
        request_uuid,
        received_uuid_str
    );

    tracing::info!(
        "✅ Request ID annotation test passed! Sent UUID: {}, Received: {}",
        request_uuid,
        received_uuid_str
    );

    cancel_token.cancel();
    task.await.unwrap().unwrap();
}