Unverified Commit 98d4abbb authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: add more metrics to rust frontend (#1315)


Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarjothomson <jwillthomson19@gmail.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent f8213242
...@@ -407,6 +407,9 @@ impl ...@@ -407,6 +407,9 @@ impl
id: None, id: None,
data: Some(delta), data: Some(delta),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
}; };
yield ann; yield ann;
...@@ -566,6 +569,9 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -566,6 +569,9 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
id: None, id: None,
data: Some(inner), data: Some(inner),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
}; };
yield ann; yield ann;
......
...@@ -202,7 +202,7 @@ impl ...@@ -202,7 +202,7 @@ impl
let response = NvCreateChatCompletionStreamResponse { let response = NvCreateChatCompletionStreamResponse {
inner, inner,
}; };
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, chunk_tokens: None, input_tokens: None, output_tokens: None, comment: None };
id += 1; id += 1;
} }
...@@ -210,7 +210,7 @@ impl ...@@ -210,7 +210,7 @@ impl
let response = NvCreateChatCompletionStreamResponse { let response = NvCreateChatCompletionStreamResponse {
inner, inner,
}; };
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, chunk_tokens: None, input_tokens: None, output_tokens: None, comment: None };
}; };
Ok(ResponseStream::new(Box::pin(output), ctx)) Ok(ResponseStream::new(Box::pin(output), ctx))
...@@ -234,11 +234,11 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -234,11 +234,11 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
for c in chars_string.chars() { for c in chars_string.chars() {
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None); let response = deltas.create_choice(0, Some(c.to_string()), None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, chunk_tokens: None, input_tokens: None, output_tokens: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, Some("stop".to_string())); let response = deltas.create_choice(0, None, Some("stop".to_string()));
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, chunk_tokens: None, input_tokens: None, output_tokens: None, comment: None };
}; };
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router}; use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router};
use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts}; use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts};
use std::{sync::Arc, time::Instant}; use std::{
sync::Arc,
time::{Duration, Instant},
};
pub use prometheus::Registry; pub use prometheus::Registry;
...@@ -25,6 +28,10 @@ pub struct Metrics { ...@@ -25,6 +28,10 @@ pub struct Metrics {
request_counter: IntCounterVec, request_counter: IntCounterVec,
inflight_gauge: IntGaugeVec, inflight_gauge: IntGaugeVec,
request_duration: HistogramVec, request_duration: HistogramVec,
input_sequence_length: HistogramVec,
output_sequence_length: HistogramVec,
time_to_first_token: HistogramVec,
inter_token_latency: HistogramVec,
} }
/// RAII object for inflight gauge and request counters /// RAII object for inflight gauge and request counters
...@@ -68,6 +75,20 @@ pub enum Status { ...@@ -68,6 +75,20 @@ pub enum Status {
Error, Error,
} }
/// Track response-specific metrics
pub struct ResponseMetricCollector {
metrics: Arc<Metrics>,
model: String,
start_time: Instant,
// we use is_first_token to distinguish TTFT from ITL. It is true by default and
// flipped to false when the first token is returned and TTFT is published.
is_first_token: bool,
// we track the last response time so that ITL for the newly returned tokens can
// be computed.
last_response_time: Option<Duration>,
osl: usize,
}
impl Default for Metrics { impl Default for Metrics {
fn default() -> Self { fn default() -> Self {
Self::new("nv_llm") Self::new("nv_llm")
...@@ -80,6 +101,10 @@ impl Metrics { ...@@ -80,6 +101,10 @@ impl Metrics {
/// - `{prefix}_http_service_requests_total` - IntCounterVec for the total number of requests processed /// - `{prefix}_http_service_requests_total` - IntCounterVec for the total number of requests processed
/// - `{prefix}_http_service_inflight_requests` - IntGaugeVec for the number of inflight requests /// - `{prefix}_http_service_inflight_requests` - IntGaugeVec for the number of inflight requests
/// - `{prefix}_http_service_request_duration_seconds` - HistogramVec for the duration of requests /// - `{prefix}_http_service_request_duration_seconds` - HistogramVec for the duration of requests
/// - `{prefix}_http_service_input_sequence_tokens` - HistogramVec for input sequence length in tokens
/// - `{prefix}_http_service_output_sequence_tokens` - HistogramVec for output sequence length in tokens
/// - `{prefix}_http_service_time_to_first_token_seconds` - HistogramVec for time to first token in seconds
/// - `{prefix}_http_service_inter_token_latency_seconds` - HistogramVec for inter-token latency in seconds
pub fn new(prefix: &str) -> Self { pub fn new(prefix: &str) -> Self {
let request_counter = IntCounterVec::new( let request_counter = IntCounterVec::new(
Opts::new( Opts::new(
...@@ -111,10 +136,64 @@ impl Metrics { ...@@ -111,10 +136,64 @@ impl Metrics {
) )
.unwrap(); .unwrap();
let input_sequence_length = HistogramVec::new(
HistogramOpts::new(
format!("{}_http_service_input_sequence_tokens", prefix),
"Input sequence length in tokens",
)
.buckets(vec![
0.0, 50.0, 100.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0, 16000.0, 32000.0, 64000.0,
128000.0,
]),
&["model"],
)
.unwrap();
let output_sequence_length = HistogramVec::new(
HistogramOpts::new(
format!("{}_http_service_output_sequence_tokens", prefix),
"Output sequence length in tokens",
)
.buckets(vec![
0.0, 50.0, 100.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0, 16000.0, 32000.0,
]),
&["model"],
)
.unwrap();
let time_to_first_token = HistogramVec::new(
HistogramOpts::new(
format!("{}_http_service_time_to_first_token_seconds", prefix),
"Time to first token in seconds",
)
.buckets(vec![
0.0, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0,
60.0, 120.0, 240.0, 480.0,
]),
&["model"],
)
.unwrap();
let inter_token_latency = HistogramVec::new(
HistogramOpts::new(
format!("{}_http_service_inter_token_latency_seconds", prefix),
"Inter-token latency in seconds",
)
.buckets(vec![
0.0, 0.001, 0.005, 0.01, 0.015, 0.02, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.0,
]),
&["model"],
)
.unwrap();
Metrics { Metrics {
request_counter, request_counter,
inflight_gauge, inflight_gauge,
request_duration, request_duration,
input_sequence_length,
output_sequence_length,
time_to_first_token,
inter_token_latency,
} }
} }
...@@ -179,6 +258,10 @@ impl Metrics { ...@@ -179,6 +258,10 @@ impl Metrics {
registry.register(Box::new(self.request_counter.clone()))?; registry.register(Box::new(self.request_counter.clone()))?;
registry.register(Box::new(self.inflight_gauge.clone()))?; registry.register(Box::new(self.inflight_gauge.clone()))?;
registry.register(Box::new(self.request_duration.clone()))?; registry.register(Box::new(self.request_duration.clone()))?;
registry.register(Box::new(self.input_sequence_length.clone()))?;
registry.register(Box::new(self.output_sequence_length.clone()))?;
registry.register(Box::new(self.time_to_first_token.clone()))?;
registry.register(Box::new(self.inter_token_latency.clone()))?;
Ok(()) Ok(())
} }
...@@ -199,7 +282,17 @@ impl Metrics { ...@@ -199,7 +282,17 @@ impl Metrics {
RequestType::Unary RequestType::Unary
}; };
InflightGuard::new(self.clone(), model.to_string(), endpoint, request_type) InflightGuard::new(
self.clone(),
model.to_string().to_lowercase(),
endpoint,
request_type,
)
}
/// Create a new [`ResponseMetricCollector`] for collecting per-response metrics (i.e., TTFT, ITL)
pub fn create_response_collector(self: Arc<Self>, model: &str) -> ResponseMetricCollector {
ResponseMetricCollector::new(self, model.to_string().to_lowercase())
} }
} }
...@@ -293,6 +386,76 @@ impl Status { ...@@ -293,6 +386,76 @@ impl Status {
} }
} }
impl ResponseMetricCollector {
fn new(metrics: Arc<Metrics>, model: String) -> Self {
ResponseMetricCollector {
metrics,
model,
is_first_token: true,
last_response_time: None,
start_time: Instant::now(),
osl: 0,
}
}
/// Observe the current output sequence length
pub fn observe_current_osl(&mut self, osl: usize) {
self.osl = osl;
}
/// Observe a response with input sequence length and number of new tokens
pub fn observe_response(&mut self, isl: usize, num_tokens: usize) {
if num_tokens == 0 {
return;
}
if self.is_first_token {
// NOTE: when there are multiple tokens in the first response,
// we use the full response time as TTFT and ignore the ITL
self.is_first_token = false;
// Publish TTFT
let ttft = self.start_time.elapsed().as_secs_f64();
self.metrics
.time_to_first_token
.with_label_values(&[&self.model])
.observe(ttft);
// Publish ISL
// TODO: publish ISL as soon as the tokenization process completes
self.metrics
.input_sequence_length
.with_label_values(&[&self.model])
.observe(isl as f64);
}
let current_duration = self.start_time.elapsed();
if let Some(last_response_time) = self.last_response_time {
let response_duration = current_duration - last_response_time;
let itl = response_duration.as_secs_f64() / num_tokens as f64;
for _ in 0..num_tokens {
self.metrics
.inter_token_latency
.with_label_values(&[&self.model])
.observe(itl);
}
}
self.last_response_time = Some(current_duration);
}
}
impl Drop for ResponseMetricCollector {
fn drop(&mut self) {
// Publish final OSL when the collector is dropped
self.metrics
.output_sequence_length
.with_label_values(&[&self.model])
.observe(self.osl as f64);
}
}
/// Create a new router with the given path /// Create a new router with the given path
pub fn router(registry: Registry, path: Option<String>) -> (Vec<RouteDoc>, Router) { pub fn router(registry: Registry, path: Option<String>) -> (Vec<RouteDoc>, Router) {
let registry = Arc::new(registry); let registry = Arc::new(registry);
......
...@@ -23,7 +23,7 @@ use tokio_stream::wrappers::ReceiverStream; ...@@ -23,7 +23,7 @@ use tokio_stream::wrappers::ReceiverStream;
use super::{ use super::{
error::HttpError, error::HttpError,
metrics::{Endpoint, InflightGuard}, metrics::{Endpoint, InflightGuard, ResponseMetricCollector},
service_v2, RouteDoc, service_v2, RouteDoc,
}; };
...@@ -152,12 +152,13 @@ async fn completions( ...@@ -152,12 +152,13 @@ async fn completions(
.get_completions_engine(model) .get_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorResponse::model_not_found())?;
// this will increment the inflight gauge for the model let mut inflight_guard =
let mut inflight =
state state
.metrics_clone() .metrics_clone()
.create_inflight_guard(model, Endpoint::Completions, streaming); .create_inflight_guard(model, Endpoint::Completions, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(model);
// setup context // setup context
// todo - inherit request_id from distributed trace details // todo - inherit request_id from distributed trace details
let request = Context::with_id(request, request_id.clone()); let request = Context::with_id(request, request_id.clone());
...@@ -175,8 +176,10 @@ async fn completions( ...@@ -175,8 +176,10 @@ async fn completions(
// note - we might do this as part of the post processing set to make it more generic // note - we might do this as part of the post processing set to make it more generic
if streaming { if streaming {
let stream = stream.map(|response| Event::try_from(EventConverter::from(response))); let stream = stream.map(move |response| {
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight).await; process_event_converter(EventConverter::from(response), &mut response_collector)
});
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight_guard).await;
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
...@@ -186,6 +189,7 @@ async fn completions( ...@@ -186,6 +189,7 @@ async fn completions(
Ok(sse_stream.into_response()) Ok(sse_stream.into_response())
} else { } else {
// TODO: report ISL/OSL for non-streaming requests
let response = CompletionResponse::from_annotated_stream(stream.into()) let response = CompletionResponse::from_annotated_stream(stream.into())
.await .await
.map_err(|e| { .map_err(|e| {
...@@ -197,7 +201,7 @@ async fn completions( ...@@ -197,7 +201,7 @@ async fn completions(
ErrorResponse::internal_server_error("Failed to fold completions stream") ErrorResponse::internal_server_error("Failed to fold completions stream")
})?; })?;
inflight.mark_ok(); inflight_guard.mark_ok();
Ok(Json(response).into_response()) Ok(Json(response).into_response())
} }
} }
...@@ -269,12 +273,13 @@ async fn chat_completions( ...@@ -269,12 +273,13 @@ async fn chat_completions(
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorResponse::model_not_found())?;
// this will increment the inflight gauge for the model let mut inflight_guard =
let mut inflight =
state state
.metrics_clone() .metrics_clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, streaming); .create_inflight_guard(model, Endpoint::ChatCompletions, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(model);
// setup context // setup context
// todo - inherit request_id from distributed trace details // todo - inherit request_id from distributed trace details
let request = Context::with_id(request, request_id.clone()); let request = Context::with_id(request, request_id.clone());
...@@ -294,8 +299,10 @@ async fn chat_completions( ...@@ -294,8 +299,10 @@ async fn chat_completions(
// note - we might do this as part of the post processing set to make it more generic // note - we might do this as part of the post processing set to make it more generic
if streaming { if streaming {
let stream = stream.map(|response| Event::try_from(EventConverter::from(response))); let stream = stream.map(move |response| {
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight).await; process_event_converter(EventConverter::from(response), &mut response_collector)
});
let stream = monitor_for_disconnects(stream.boxed(), ctx, inflight_guard).await;
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
...@@ -305,6 +312,7 @@ async fn chat_completions( ...@@ -305,6 +312,7 @@ async fn chat_completions(
Ok(sse_stream.into_response()) Ok(sse_stream.into_response())
} else { } else {
// TODO: report ISL/OSL for non-streaming requests
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into())
.await .await
.map_err(|e| { .map_err(|e| {
...@@ -319,7 +327,7 @@ async fn chat_completions( ...@@ -319,7 +327,7 @@ async fn chat_completions(
)) ))
})?; })?;
inflight.mark_ok(); inflight_guard.mark_ok();
Ok(Json(response).into_response()) Ok(Json(response).into_response())
} }
} }
...@@ -399,12 +407,11 @@ async fn monitor_for_disconnects( ...@@ -399,12 +407,11 @@ async fn monitor_for_disconnects(
Box<dyn Stream<Item = Result<axum::response::sse::Event, axum::Error>> + std::marker::Send>, Box<dyn Stream<Item = Result<axum::response::sse::Event, axum::Error>> + std::marker::Send>,
>, >,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
inflight: InflightGuard, mut inflight_guard: InflightGuard,
) -> ReceiverStream<Result<Event, axum::Error>> { ) -> ReceiverStream<Result<Event, axum::Error>> {
let (tx, rx) = tokio::sync::mpsc::channel(8); let (tx, rx) = tokio::sync::mpsc::channel(8);
tokio::spawn(async move { tokio::spawn(async move {
let mut inflight = inflight;
let mut stream = stream; let mut stream = stream;
while let Some(event) = stream.next().await { while let Some(event) = stream.next().await {
let event = match event { let event = match event {
...@@ -419,10 +426,9 @@ async fn monitor_for_disconnects( ...@@ -419,10 +426,9 @@ async fn monitor_for_disconnects(
} }
} }
// the stream completed successfully - mark as ok // Stream completed successfully - mark as ok
// this will increment the request counter with an "success" status
if tx.send(Ok(Event::default().data("[DONE]"))).await.is_ok() { if tx.send(Ok(Event::default().data("[DONE]"))).await.is_ok() {
inflight.mark_ok(); inflight_guard.mark_ok();
} }
}); });
...@@ -437,14 +443,10 @@ impl<T> From<Annotated<T>> for EventConverter<T> { ...@@ -437,14 +443,10 @@ impl<T> From<Annotated<T>> for EventConverter<T> {
} }
} }
/// Convert an Annotated into an Event fn process_event_converter<T: Serialize>(
/// If the Event represents an Error, then return an axum::Error annotated: EventConverter<T>,
/// The [`monitor_for_disconnects`] method will handle the error, emit to the sse stream response_collector: &mut ResponseMetricCollector,
/// then stop the generation of completions. ) -> Result<Event, axum::Error> {
impl<T: Serialize> TryFrom<EventConverter<T>> for Event {
type Error = axum::Error;
fn try_from(annotated: EventConverter<T>) -> Result<Self, Self::Error> {
let annotated = annotated.0; let annotated = annotated.0;
let mut event = Event::default(); let mut event = Event::default();
...@@ -463,6 +465,16 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event { ...@@ -463,6 +465,16 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event {
event = event.event(msg); event = event.event(msg);
} }
if let Some(osl) = annotated.output_tokens {
response_collector.observe_current_osl(osl);
}
if let Some(isl) = annotated.input_tokens {
if let Some(chunk_tokens) = annotated.chunk_tokens {
response_collector.observe_response(isl, chunk_tokens);
}
}
if let Some(comments) = annotated.comment { if let Some(comments) = annotated.comment {
for comment in comments { for comment in comments {
event = event.comment(comment); event = event.comment(comment);
...@@ -470,7 +482,6 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event { ...@@ -470,7 +482,6 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event {
} }
Ok(event) Ok(event)
}
} }
/// Create an Axum [`Router`] for the OpenAI API Completions endpoint /// Create an Axum [`Router`] for the OpenAI API Completions endpoint
......
...@@ -193,6 +193,7 @@ impl OpenAIPreprocessor { ...@@ -193,6 +193,7 @@ impl OpenAIPreprocessor {
response_generator: Box<dyn DeltaGeneratorExt<Resp>>, response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
cancelled: bool, cancelled: bool,
cumulative_output_tokens: usize,
} }
let state = State { let state = State {
...@@ -200,6 +201,7 @@ impl OpenAIPreprocessor { ...@@ -200,6 +201,7 @@ impl OpenAIPreprocessor {
response_generator: generator, response_generator: generator,
context: context.clone(), context: context.clone(),
cancelled: false, cancelled: false,
cumulative_output_tokens: 0,
}; };
// transform the common response stream into a chat response stream // transform the common response stream into a chat response stream
...@@ -220,7 +222,20 @@ impl OpenAIPreprocessor { ...@@ -220,7 +222,20 @@ impl OpenAIPreprocessor {
response response
); );
let response = response.map_data(|data| { let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
let chunk_tokens = backend_output.token_ids.len();
inner.cumulative_output_tokens += chunk_tokens;
let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;
(chunk_tokens, isl)
} else {
(0, 0)
};
let current_osl = inner.cumulative_output_tokens;
let mut response = response.map_data(|data| {
inner inner
.response_generator .response_generator
.choice_from_postprocessor(data) .choice_from_postprocessor(data)
...@@ -236,6 +251,10 @@ impl OpenAIPreprocessor { ...@@ -236,6 +251,10 @@ impl OpenAIPreprocessor {
.map_err(|e| e.to_string()) .map_err(|e| e.to_string())
}); });
response.chunk_tokens = Some(chunk_tokens);
response.input_tokens = Some(isl);
response.output_tokens = Some(current_osl);
tracing::trace!( tracing::trace!(
request_id = inner.context.id(), request_id = inner.context.id(),
"OpenAI NvCreateChatCompletionStreamResponse: {:?}", "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
......
...@@ -118,6 +118,9 @@ where ...@@ -118,6 +118,9 @@ where
data, data,
id: value.id, id: value.id,
event: value.event, event: value.event,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: value.comments, comment: value.comments,
}) })
} }
......
...@@ -307,6 +307,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu ...@@ -307,6 +307,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
&mut self, &mut self,
response: common::llm_backend::BackendOutput, response: common::llm_backend::BackendOutput,
) -> Result<ResponseType>; ) -> Result<ResponseType>;
/// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>;
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -284,6 +284,9 @@ mod tests { ...@@ -284,6 +284,9 @@ mod tests {
data: Some(data), data: Some(data),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
} }
} }
...@@ -427,6 +430,9 @@ mod tests { ...@@ -427,6 +430,9 @@ mod tests {
data: Some(data), data: Some(data),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
}; };
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
......
...@@ -212,4 +212,8 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -212,4 +212,8 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
inner: stream_response, inner: stream_response,
}) })
} }
fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens)
}
} }
...@@ -205,6 +205,9 @@ mod tests { ...@@ -205,6 +205,9 @@ mod tests {
}), }),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
} }
} }
...@@ -314,6 +317,9 @@ mod tests { ...@@ -314,6 +317,9 @@ mod tests {
}), }),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
}; };
......
...@@ -126,4 +126,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe ...@@ -126,4 +126,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
let index = 0; let index = 0;
Ok(self.create_choice(index, delta.text, finish_reason)) Ok(self.create_choice(index, delta.text, finish_reason))
} }
// TODO: This is a hack. Change `prompt_tokens` to u32
fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens as u32)
}
} }
...@@ -37,6 +37,12 @@ pub struct Annotated<R> { ...@@ -37,6 +37,12 @@ pub struct Annotated<R> {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub event: Option<String>, pub event: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub chunk_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub comment: Option<Vec<String>>, pub comment: Option<Vec<String>>,
} }
...@@ -47,6 +53,9 @@ impl<R> Annotated<R> { ...@@ -47,6 +53,9 @@ impl<R> Annotated<R> {
data: None, data: None,
id: None, id: None,
event: Some("error".to_string()), event: Some("error".to_string()),
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: Some(vec![error]), comment: Some(vec![error]),
} }
} }
...@@ -57,6 +66,9 @@ impl<R> Annotated<R> { ...@@ -57,6 +66,9 @@ impl<R> Annotated<R> {
data: Some(data), data: Some(data),
id: None, id: None,
event: None, event: None,
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: None, comment: None,
} }
} }
...@@ -72,6 +84,9 @@ impl<R> Annotated<R> { ...@@ -72,6 +84,9 @@ impl<R> Annotated<R> {
data: None, data: None,
id: None, id: None,
event: Some(name.into()), event: Some(name.into()),
chunk_tokens: None,
input_tokens: None,
output_tokens: None,
comment: Some(vec![serde_json::to_string(value)?]), comment: Some(vec![serde_json::to_string(value)?]),
}) })
} }
...@@ -107,6 +122,9 @@ impl<R> Annotated<R> { ...@@ -107,6 +122,9 @@ impl<R> Annotated<R> {
data, data,
id: self.id, id: self.id,
event: self.event, event: self.event,
chunk_tokens: self.chunk_tokens,
input_tokens: self.input_tokens,
output_tokens: self.output_tokens,
comment: self.comment, comment: self.comment,
} }
} }
...@@ -122,6 +140,9 @@ impl<R> Annotated<R> { ...@@ -122,6 +140,9 @@ impl<R> Annotated<R> {
data, data,
id: self.id, id: self.id,
event: self.event, event: self.event,
chunk_tokens: self.chunk_tokens,
input_tokens: self.input_tokens,
output_tokens: self.output_tokens,
comment: self.comment, comment: self.comment,
}, },
Err(e) => Annotated::from_error(e), Err(e) => Annotated::from_error(e),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment