Unverified Commit a9e0891c authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding http clients and recorded response stream (#1919)

parent 4128d583
...@@ -45,7 +45,7 @@ dynamo-tokens = { path = "lib/tokens", version = "0.3.2" } ...@@ -45,7 +45,7 @@ dynamo-tokens = { path = "lib/tokens", version = "0.3.2" }
# External dependencies # External dependencies
anyhow = { version = "1" } anyhow = { version = "1" }
async-nats = { version = "0.40", features = ["service"] } async-nats = { version = "0.40", features = ["service"] }
async-openai = { version = "0.29.0" } async-openai = { version = "0.29.0", features = ["rustls", "byot"] }
async-stream = { version = "0.3" } async-stream = { version = "0.3" }
async-trait = { version = "0.1" } async-trait = { version = "0.1" }
async_zmq = { version = "0.4.0" } async_zmq = { version = "0.4.0" }
......
...@@ -13,4 +13,5 @@ ...@@ -13,4 +13,5 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod client;
pub mod service; pub mod service;
This diff is collapsed.
...@@ -25,6 +25,7 @@ pub mod local_model; ...@@ -25,6 +25,7 @@ pub mod local_model;
pub mod mocker; pub mod mocker;
pub mod model_card; pub mod model_card;
pub mod model_type; pub mod model_type;
pub mod perf;
pub mod preprocessor; pub mod preprocessor;
pub mod protocols; pub mod protocols;
pub mod recorder; pub mod recorder;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Performance recording and analysis for streaming LLM responses
//!
//! This module provides mechanisms to record streaming responses with minimal overhead
//! during collection, then analyze the recorded data for performance insights.
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
// Import the runtime types we need
use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
EngineStream, ResponseStream,
};
use std::sync::Arc;
/// Type alias for a receiver of recorded stream data
pub type RecordedStreamReceiver<R> = oneshot::Receiver<RecordedStream<R>>;
/// Type alias for the return type of recording functions
pub type RecordingResult<R> = (EngineStream<R>, RecordedStreamReceiver<R>);
/// A response wrapper that adds timing information with minimal overhead
#[derive(Debug, Clone)]
pub struct TimestampedResponse<T> {
/// The actual response data
pub response: T,
/// High-resolution timestamp when this response was recorded
pub timestamp: Instant,
/// Sequence number in the stream (0-based)
pub sequence_number: usize,
}
impl<T> TimestampedResponse<T> {
/// Create a new timestamped response
pub fn new(response: T, sequence_number: usize) -> Self {
Self {
response,
timestamp: Instant::now(),
sequence_number,
}
}
/// Get the response data
pub fn data(&self) -> &T {
&self.response
}
/// Get the elapsed time since stream start
pub fn elapsed_since(&self, start_time: Instant) -> Duration {
self.timestamp.duration_since(start_time)
}
}
/// Trait for requests that can provide hints about expected response count
/// This enables capacity pre-allocation for better performance
pub trait CapacityHint {
/// Estimate the number of responses this request might generate
/// Returns None if estimation is not possible
fn estimated_response_count(&self) -> Option<usize>;
}
/// Recording mode determines how the recorder behaves with the stream
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecordingMode {
/// Pass responses through while recording (scan mode)
/// Stream continues to flow to downstream consumers
Scan,
/// Consume responses as terminus (sink mode)
/// Stream ends at the recorder
Sink,
}
/// Container for recorded streaming responses.
/// This forms the core object on which analysis is performed.
#[derive(Debug, Clone)]
pub struct RecordedStream<T> {
/// All recorded responses with timestamps
responses: Vec<TimestampedResponse<T>>,
/// When recording started
start_time: Instant,
/// When recording ended
end_time: Instant,
}
impl<T> RecordedStream<T> {
/// Create a new recorded stream from collected responses
pub fn new(
responses: Vec<TimestampedResponse<T>>,
start_time: Instant,
end_time: Instant,
) -> Self {
Self {
responses,
start_time,
end_time,
}
}
/// Get the number of responses recorded
pub fn response_count(&self) -> usize {
self.responses.len()
}
/// Get the total duration of the stream
pub fn total_duration(&self) -> Duration {
self.end_time.duration_since(self.start_time)
}
/// Get the responses recorded
pub fn responses(&self) -> &[TimestampedResponse<T>] {
&self.responses
}
/// Get the start time of the stream
pub fn start_time(&self) -> &Instant {
&self.start_time
}
/// Get the end time of the stream
pub fn end_time(&self) -> &Instant {
&self.end_time
}
}
/// Recording stream that wraps an AsyncEngineStream and records responses
/// Following the pattern of ResponseStream for AsyncEngine compatibility
pub struct RecordingStream<R: Data> {
/// The wrapped stream
stream: DataStream<R>,
/// Context from the original stream
ctx: Arc<dyn AsyncEngineContext>,
/// Recording mode
mode: RecordingMode,
/// Recorded responses
responses: Vec<TimestampedResponse<R>>,
/// When recording started
start_time: Instant,
/// Channel to send recorded data when stream completes
recorded_tx: Option<oneshot::Sender<RecordedStream<R>>>,
}
impl<R: Data> Unpin for RecordingStream<R> {}
impl<R: Data + Clone> RecordingStream<R> {
/// Create a new recording stream from a raw stream and context
pub fn from_stream_and_context(
stream: DataStream<R>,
ctx: Arc<dyn AsyncEngineContext>,
mode: RecordingMode,
capacity: Option<usize>,
recorded_tx: oneshot::Sender<RecordedStream<R>>,
) -> Self {
let mut responses = Vec::new();
if let Some(cap) = capacity {
responses.reserve(cap);
}
Self {
stream,
ctx,
mode,
responses,
start_time: Instant::now(),
recorded_tx: Some(recorded_tx),
}
}
/// Create a new recording stream from an AsyncEngineStream (private constructor)
fn from_async_engine_stream(
stream: EngineStream<R>,
mode: RecordingMode,
capacity: Option<usize>,
recorded_tx: oneshot::Sender<RecordedStream<R>>,
) -> Self {
let ctx = stream.context();
Self::from_stream_and_context(stream, ctx, mode, capacity, recorded_tx)
}
/// Convert to Pin<Box<dyn AsyncEngineStream<R>>>
pub fn into_async_engine_stream(self) -> EngineStream<R> {
Box::pin(self)
}
}
impl<R: Data + Clone> Stream for RecordingStream<R> {
type Item = R;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.as_mut().get_mut();
match Pin::new(&mut this.stream).poll_next(cx) {
Poll::Ready(Some(item)) => {
// Always capture timestamp first (cheap operation)
let timestamp = Instant::now();
let sequence_number = this.responses.len();
match this.mode {
RecordingMode::Scan => {
// Clone for recording, pass original through
let timestamped = TimestampedResponse {
response: item.clone(),
timestamp,
sequence_number,
};
this.responses.push(timestamped);
Poll::Ready(Some(item)) // Pass through original
}
RecordingMode::Sink => {
// Move item directly into recording (no clone needed)
let timestamped = TimestampedResponse {
response: item, // Move, don't clone
timestamp,
sequence_number,
};
this.responses.push(timestamped);
// Continue consuming but don't emit
// self.poll_next(cx)
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
Poll::Ready(None) => {
// Stream ended - send recorded data
if let Some(tx) = this.recorded_tx.take() {
let recorded = RecordedStream::new(
std::mem::take(&mut this.responses),
this.start_time,
Instant::now(),
);
let _ = tx.send(recorded); // Ignore if receiver dropped
}
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
impl<R: Data + Clone> AsyncEngineStream<R> for RecordingStream<R> {}
impl<R: Data + Clone> AsyncEngineContextProvider for RecordingStream<R> {
fn context(&self) -> Arc<dyn AsyncEngineContext> {
self.ctx.clone()
}
}
impl<R: Data + Clone> std::fmt::Debug for RecordingStream<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RecordingStream")
.field("mode", &self.mode)
.field("responses_count", &self.responses.len())
.field("ctx", &self.ctx)
.finish()
}
}
/// Create a recording stream that wraps an AsyncEngineStream
/// Returns a pinned stream and a receiver for the recorded data
pub fn record_stream<R: Data + Clone>(
stream: EngineStream<R>,
mode: RecordingMode,
) -> RecordingResult<R> {
let (tx, rx) = oneshot::channel();
let recording_stream = RecordingStream::from_async_engine_stream(stream, mode, None, tx);
let boxed_stream = Box::pin(recording_stream);
(boxed_stream, rx)
}
/// Create a recording stream from a raw stream and context
/// Returns a pinned stream and a receiver for the recorded data
pub fn record_stream_with_context<R: Data + Clone>(
stream: DataStream<R>,
ctx: Arc<dyn AsyncEngineContext>,
mode: RecordingMode,
) -> RecordingResult<R> {
let (tx, rx) = oneshot::channel();
let recording_stream = RecordingStream::from_stream_and_context(stream, ctx, mode, None, tx);
let boxed_stream = Box::pin(recording_stream);
(boxed_stream, rx)
}
/// Create a recording stream with capacity hint
pub fn record_stream_with_capacity<R: Data + Clone>(
stream: EngineStream<R>,
mode: RecordingMode,
capacity: usize,
) -> RecordingResult<R> {
let (tx, rx) = oneshot::channel();
let recording_stream =
RecordingStream::from_async_engine_stream(stream, mode, Some(capacity), tx);
let boxed_stream = Box::pin(recording_stream);
(boxed_stream, rx)
}
/// Create a recording stream with capacity hint from request
pub fn record_stream_with_request_hint<R: Data + Clone, Req: CapacityHint>(
stream: EngineStream<R>,
mode: RecordingMode,
request: &Req,
) -> RecordingResult<R> {
let capacity = request.estimated_response_count();
match capacity {
Some(cap) => record_stream_with_capacity(stream, mode, cap),
None => record_stream(stream, mode),
}
}
/// Create a recording stream from a raw stream and context with capacity hint
pub fn record_stream_with_context_and_capacity<R: Data + Clone>(
stream: DataStream<R>,
ctx: Arc<dyn AsyncEngineContext>,
mode: RecordingMode,
capacity: usize,
) -> RecordingResult<R> {
let (tx, rx) = oneshot::channel();
let recording_stream =
RecordingStream::from_stream_and_context(stream, ctx, mode, Some(capacity), tx);
let boxed_stream = Box::pin(recording_stream);
(boxed_stream, rx)
}
/// Create a recording stream from ResponseStream (convenience wrapper)
pub fn record_response_stream<R: Data + Clone>(
response_stream: Pin<Box<ResponseStream<R>>>,
mode: RecordingMode,
) -> RecordingResult<R> {
record_stream(response_stream, mode)
}
#[cfg(test)]
mod tests {
use super::*;
use dynamo_runtime::engine::ResponseStream;
use futures::stream;
use std::time::Duration;
#[test]
fn test_timestamped_response_creation() {
let response = "test response";
let timestamped = TimestampedResponse::new(response, 0);
assert_eq!(timestamped.response, response);
assert_eq!(timestamped.sequence_number, 0);
assert_eq!(timestamped.data(), &response);
}
#[test]
fn test_recorded_stream_analysis() {
let start_time = Instant::now();
// Create mock responses with known timing
let responses = vec![
TimestampedResponse {
response: "response1",
timestamp: start_time,
sequence_number: 0,
},
TimestampedResponse {
response: "response2",
timestamp: start_time + Duration::from_millis(100),
sequence_number: 1,
},
TimestampedResponse {
response: "response3",
timestamp: start_time + Duration::from_millis(250),
sequence_number: 2,
},
];
let end_time = start_time + Duration::from_millis(250);
let recorded = RecordedStream::new(responses, start_time, end_time);
assert_eq!(recorded.response_count(), 3);
assert_eq!(recorded.total_duration(), Duration::from_millis(250));
}
#[test]
fn test_performance_metrics_conversion() {
let start_time = Instant::now();
let responses = vec![
TimestampedResponse {
response: "test",
timestamp: start_time + Duration::from_millis(50),
sequence_number: 0,
},
TimestampedResponse {
response: "test",
timestamp: start_time + Duration::from_millis(150),
sequence_number: 1,
},
];
let end_time = start_time + Duration::from_millis(150);
let recorded = RecordedStream::new(responses, start_time, end_time);
assert_eq!(recorded.response_count(), 2);
assert_eq!(recorded.total_duration(), Duration::from_millis(150));
}
#[tokio::test]
async fn test_recording_stream_scan_mode() {
use futures::StreamExt;
// Create a simple test stream
let test_data = vec!["token1", "token2", "token3"];
let base_stream = stream::iter(test_data.clone());
// Create a mock context for the stream
let ctx = Arc::new(MockContext::new());
// Record the stream in scan mode using the simplified API
let (recorded_stream, recording_rx) =
record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Scan);
// Consume the stream normally (pass-through mode)
let collected_responses: Vec<_> = recorded_stream.collect().await;
// Verify the responses passed through unchanged
assert_eq!(collected_responses, test_data);
// Get the recorded data
let recorded = recording_rx.await.unwrap();
assert_eq!(recorded.response_count(), 3);
assert_eq!(recorded.responses[0].response, "token1");
assert_eq!(recorded.responses[1].response, "token2");
assert_eq!(recorded.responses[2].response, "token3");
// Verify timing was recorded
assert!(recorded.total_duration() > Duration::from_nanos(0));
}
#[tokio::test]
async fn test_recording_stream_sink_mode() {
use futures::StreamExt;
// Create a simple test stream
let test_data = vec!["token1", "token2", "token3"];
let base_stream = stream::iter(test_data.clone());
// Create a mock context for the stream
let ctx = Arc::new(MockContext::new());
// Record the stream in sink mode using the simplified API
let (recorded_stream, recording_rx) =
record_stream_with_context(Box::pin(base_stream), ctx, RecordingMode::Sink);
// In sink mode, the stream should complete without emitting items
let collected_responses: Vec<_> = recorded_stream.collect().await;
assert_eq!(collected_responses, Vec::<&str>::new());
// Get the recorded data - should contain all original items
let recorded = recording_rx.await.unwrap();
assert_eq!(recorded.response_count(), 3);
assert_eq!(recorded.responses[0].response, "token1");
assert_eq!(recorded.responses[1].response, "token2");
assert_eq!(recorded.responses[2].response, "token3");
// Verify timing was recorded
assert!(recorded.total_duration() > Duration::from_nanos(0));
}
#[tokio::test]
async fn test_recording_stream_from_response_stream() {
use futures::StreamExt;
// Create a simple test stream
let test_data = vec!["token1", "token2", "token3"];
let base_stream = stream::iter(test_data.clone());
// Create a ResponseStream (the traditional way)
let ctx = Arc::new(MockContext::new());
let response_stream = ResponseStream::new(Box::pin(base_stream), ctx);
// Use the convenience API for ResponseStream
let (recorded_stream, recording_rx) =
record_response_stream(response_stream, RecordingMode::Scan);
// Consume the stream normally (pass-through mode)
let collected_responses: Vec<_> = recorded_stream.collect().await;
// Verify the responses passed through unchanged
assert_eq!(collected_responses, test_data);
// Get the recorded data
let recorded = recording_rx.await.unwrap();
assert_eq!(recorded.response_count(), 3);
assert_eq!(recorded.responses[0].response, "token1");
assert_eq!(recorded.responses[1].response, "token2");
assert_eq!(recorded.responses[2].response, "token3");
// Verify timing was recorded
assert!(recorded.total_duration() > Duration::from_nanos(0));
}
// Mock context for testing
#[derive(Debug)]
struct MockContext {
id: String,
}
impl MockContext {
fn new() -> Self {
Self {
id: "test-context".to_string(),
}
}
}
#[async_trait::async_trait]
impl AsyncEngineContext for MockContext {
fn id(&self) -> &str {
&self.id
}
fn stop(&self) {
// No-op for testing
}
fn stop_generating(&self) {
// No-op for testing
}
fn kill(&self) {
// No-op for testing
}
fn is_stopped(&self) -> bool {
false
}
fn is_killed(&self) -> bool {
false
}
async fn stopped(&self) {
// No-op for testing
}
async fn killed(&self) {
// No-op for testing
}
}
}
...@@ -14,12 +14,18 @@ ...@@ -14,12 +14,18 @@
// limitations under the License. // limitations under the License.
use anyhow::Error; use anyhow::Error;
use async_openai::config::OpenAIConfig;
use async_stream::stream; use async_stream::stream;
use dynamo_llm::http::service::{ use dynamo_llm::http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
},
service::{
error::HttpError, error::HttpError,
metrics::{Endpoint, RequestType, Status}, metrics::{Endpoint, RequestType, Status},
service_v2::HttpService, service_v2::HttpService,
Metrics, Metrics,
},
}; };
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
openai::{ openai::{
...@@ -29,13 +35,16 @@ use dynamo_llm::protocols::{ ...@@ -29,13 +35,16 @@ use dynamo_llm::protocols::{
Annotated, Annotated,
}; };
use dynamo_runtime::{ use dynamo_runtime::{
engine::AsyncEngineContext,
pipeline::{ pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
}, },
CancellationToken, CancellationToken,
}; };
use futures::StreamExt;
use prometheus::{proto::MetricType, Registry}; use prometheus::{proto::MetricType, Registry};
use reqwest::StatusCode; use reqwest::StatusCode;
use rstest::*;
use std::sync::Arc; use std::sync::Arc;
struct CounterEngine {} struct CounterEngine {}
...@@ -470,3 +479,404 @@ async fn test_http_service() { ...@@ -470,3 +479,404 @@ async fn test_http_service() {
cancel_token.cancel(); cancel_token.cancel();
task.await.unwrap().unwrap(); task.await.unwrap().unwrap();
} }
// === HTTP Client Tests ===
/// Wait for the HTTP service to be ready by checking its health endpoint
async fn wait_for_service_ready(port: u16) {
let start = tokio::time::Instant::now();
let timeout = tokio::time::Duration::from_secs(5);
loop {
match reqwest::get(&format!("http://localhost:{}/health", port)).await {
Ok(_) => break,
Err(_) if start.elapsed() < timeout => {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
Err(e) => panic!("Service failed to start within timeout: {}", e),
}
}
}
#[fixture]
fn service_with_engines(
#[default(8990)] port: u16,
) -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>) {
let service = HttpService::builder().port(port).build().unwrap();
let manager = service.model_manager();
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
manager
.add_chat_completions_model("foo", counter.clone())
.unwrap();
manager
.add_chat_completions_model("bar", failure.clone())
.unwrap();
manager
.add_completions_model("bar", failure.clone())
.unwrap();
(service, counter, failure)
}
#[fixture]
fn pure_openai_client(#[default(8990)] port: u16) -> PureOpenAIClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
PureOpenAIClient::new(config)
}
#[fixture]
fn nv_custom_client(#[default(8991)] port: u16) -> NvCustomClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
NvCustomClient::new(config)
}
#[fixture]
fn generic_byot_client(#[default(8992)] port: u16) -> GenericBYOTClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
GenericBYOTClient::new(config)
}
#[rstest]
#[tokio::test]
async fn test_pure_openai_client(
#[with(8990)] service_with_engines: (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>),
#[with(8990)] pure_openai_client: PureOpenAIClient,
) {
let (service, _counter, _failure) = service_with_engines;
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8990).await;
// Test successful streaming request
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(result.is_ok(), "PureOpenAI client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[rstest]
#[tokio::test]
async fn test_nv_custom_client(
#[with(8991)] service_with_engines: (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>),
#[with(8991)] nv_custom_client: NvCustomClient,
) {
let (service, _counter, _failure) = service_with_engines;
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8991).await;
// Test successful streaming request
let inner_request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
nvext: None,
};
let result = nv_custom_client.chat_stream(request).await;
assert!(result.is_ok(), "NvCustom client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let inner_request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
nvext: None,
};
let result = nv_custom_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let inner_request = async_openai::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
nvext: None,
};
let result = nv_custom_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[rstest]
#[tokio::test]
async fn test_generic_byot_client(
#[with(8992)] service_with_engines: (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>),
#[with(8992)] generic_byot_client: GenericBYOTClient,
) {
let (service, _counter, _failure) = service_with_engines;
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(8992).await;
// Test successful streaming request
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(result.is_ok(), "GenericBYOT client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
println!("Response: {:?}", response);
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = serde_json::json!({
"model": "bar", // This model will fail
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
...@@ -165,7 +165,7 @@ pub trait AsyncEngineContext: Send + Sync + Debug { ...@@ -165,7 +165,7 @@ pub trait AsyncEngineContext: Send + Sync + Debug {
/// ///
/// This trait is implemented by both unary and streaming engine results, allowing /// This trait is implemented by both unary and streaming engine results, allowing
/// uniform access to context information regardless of the operation type. /// uniform access to context information regardless of the operation type.
pub trait AsyncEngineContextProvider: Send + Sync + Debug { pub trait AsyncEngineContextProvider: Send + Debug {
fn context(&self) -> Arc<dyn AsyncEngineContext>; fn context(&self) -> Arc<dyn AsyncEngineContext>;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment