Unverified Commit 9d7c5df5 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: remove dead protocols code and organize imports idiomatically (#1669)

parent 03d976c7
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
//! both publicly via the HTTP API and internally between Dynamo components. //! both publicly via the HTTP API and internally between Dynamo components.
//! //!
use std::pin::Pin;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::pin::Pin;
pub mod codec; pub mod codec;
pub mod common; pub mod common;
...@@ -48,13 +49,6 @@ pub trait ContentProvider { ...@@ -48,13 +49,6 @@ pub trait ContentProvider {
fn content(&self) -> String; fn content(&self) -> String;
} }
/// Token counters attached to a response.
///
/// All three counters are signed 32-bit integers; nothing in this type enforces
/// any relationship between them.
/// NOTE(review): the field names match the OpenAI `usage` object — confirm that
/// is the intended wire format.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
/// Token count for the prompt, as the field name indicates.
pub prompt_tokens: i32,
/// Token count for the completion, as the field name indicates.
pub completion_tokens: i32,
/// Total token count; presumably `prompt_tokens + completion_tokens` — not checked here.
pub total_tokens: i32,
}
/// Converts a stream of [codec::Message]s into a stream of [Annotated]s. /// Converts a stream of [codec::Message]s into a stream of [Annotated]s.
pub fn convert_sse_stream<R>( pub fn convert_sse_stream<R>(
stream: DataStream<Result<codec::Message, codec::SseCodecError>>, stream: DataStream<Result<codec::Message, codec::SseCodecError>>,
......
...@@ -23,10 +23,11 @@ ...@@ -23,10 +23,11 @@
// TODO: Determine if we should use an External EventSource crate. There appear to be several // TODO: Determine if we should use an External EventSource crate. There appear to be several
// potential candidates. // potential candidates.
use std::{io::Cursor, pin::Pin};
use bytes::BytesMut; use bytes::BytesMut;
use futures::Stream; use futures::Stream;
use serde::Deserialize; use serde::Deserialize;
use std::{io::Cursor, pin::Pin};
use tokio_util::codec::{Decoder, FramedRead, LinesCodec}; use tokio_util::codec::{Decoder, FramedRead, LinesCodec};
use super::Annotated; use super::Annotated;
......
...@@ -24,13 +24,10 @@ ...@@ -24,13 +24,10 @@
//! need some additional information to propagate intermediate results for improved observability. //! need some additional information to propagate intermediate results for improved observability.
//! The metadata is transferred via the other arms of the `StreamingResponse` enum. //! The metadata is transferred via the other arms of the `StreamingResponse` enum.
//! //!
use std::collections::HashMap;
use std::time::SystemTime;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::TokenIdType; use super::TokenIdType;
...@@ -416,68 +413,6 @@ pub struct TopLogprob { ...@@ -416,68 +413,6 @@ pub struct TopLogprob {
pub bytes: Option<Vec<u8>>, pub bytes: Option<Vec<u8>>,
} }
// /// UserData is a struct that contains user-defined data that can be passed to the inference engine.
// /// This information will be use to annotate the distributed traces for improved observability.
// #[derive(Serialize, Deserialize, Debug, Clone, Default)]
// pub struct UserData {
// /// Apply server-side prompt template to the request
// pub request_uuid: Option<uuid::Uuid>,
// }
/// StreamingResponse is the primary response object for the LLM Engine. The response stream
/// can emit three different types of messages. The Initialize and Finalize messages are optional
/// and primarily used over disaggregated transports to move state from the server to the client.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingResponse {
/// Initialize transports a Prologue object which communicates the LLM Engine Context
Initialize(Option<Prologue>),
/// Step is the primary data in the response stream. It contains the StreamingCompletionResponse,
/// boxed so the payload lives behind a pointer.
Step(Box<StreamingCompletionResponse>),
/// Finalize is an optional final message in the response stream. It contains the Epilogue object which
/// is used to communicate extra information about the completion and the engine statistics.
Finalize(Option<Epilogue>),
}
// TODO(ryan) - this should be part of the internal api as it is not deserializable
// the public API should drop the Option<Arc<Stats>> in favor of Option<Stats>
// the two variants both serialize to the same json; however, the internal version
// can not be deserialized directly.
// we use the internal one on the server side to avoid the cost of cloning the Stats
// object; however, client side, we should always fully materialize the Stats object.
//
// TODO(ryan) - update this object to use an enum where we have the current definition be the
// StepResponse arm; then we will add the following arms:
// - Initialize(Prologue)
// - Step()
// - Finalize(Epilogue)
/// This is the first message that will be emitted by an Engine Response Stream.
/// It indicates that the request has been preprocessed and queued for execution on the backend.
///
/// Both fields are optional; each is populated only in the situation its own doc describes.
#[derive(Serialize, Deserialize, Debug)]
pub struct Prologue {
/// If the request was preprocessed with a prompt template, this will contain the formatted prompt.
pub formatted_prompt: Option<String>,
/// If the request did not contain token ids, this will contain the token ids that were generated
/// from tokenizing the prompt.
pub input_token_ids: Option<Vec<TokenIdType>>,
}
/// This is the final message that will be emitted by an Engine Response Stream when it
/// finishes without error. In some cases, the engine may emit an error which will indicate
/// the end of the stream. Another case in which a Finalize(Epilogue) will not be emitted is
/// if the response handler has stalled and too many responses have accumulated.
/// NOTE(review): the original sentence ended at "too many responses"; the ending above is a
/// best guess — confirm the intended condition.
///
/// The struct currently carries no fields.
#[derive(Serialize, Deserialize, Debug)]
pub struct Epilogue {}
/// Payload of one step of the response stream: a [`Delta`] plus optional log-probabilities.
///
/// Only `Debug` is derived here; the serde traits are implemented by hand elsewhere in this file.
#[derive(Debug)]
pub struct StreamingCompletionResponse {
/// The incremental output carried by this step.
pub delta: Delta,
/// Log-probability information for this step, when present.
pub logprobs: Option<ChatCompletionLogprobs>,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState { pub enum StreamState {
Active, Active,
...@@ -506,6 +441,12 @@ pub struct SequencePositionData { ...@@ -506,6 +441,12 @@ pub struct SequencePositionData {
pub logprobs: Option<LogProbs>, pub logprobs: Option<LogProbs>,
} }
/// Payload of one step of a streaming completion: a [`Delta`] plus optional log-probabilities.
///
/// Only `Debug` is derived on this type.
#[derive(Debug)]
pub struct StreamingCompletionResponse {
/// The incremental output carried by this step.
pub delta: Delta,
/// Log-probability information for this step, when present.
pub logprobs: Option<ChatCompletionLogprobs>,
}
// todo(ryan) - we need to create a DeltaBuilder which is a mutable object that can be passed // todo(ryan) - we need to create a DeltaBuilder which is a mutable object that can be passed
// around from the low-level compute engine to the high-level api. The DeltaBuilder will allow // around from the low-level compute engine to the high-level api. The DeltaBuilder will allow
// us to construct the Delta object at multiple layers in the streaming response path. // us to construct the Delta object at multiple layers in the streaming response path.
...@@ -549,134 +490,6 @@ pub struct Usage { ...@@ -549,134 +490,6 @@ pub struct Usage {
pub output_tokens_count: usize, pub output_tokens_count: usize,
} }
// todo(ryan) - we need to update this object to make it more generic
// we need to define a set of generic stats traits that allow those stats to be None
// then back them by a concrete implementation like a TrtllmStats object
/// Flat record of engine counters: request counts, KV-cache block counts, runtime memory
/// usage, iteration bookkeeping and a timestamp.
///
/// NOTE(review): units for the memory fields and the format of `timestamp` are not stated
/// in this file — confirm against the producer of these values.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Stats {
/// Time since the last Epoch/Forward Pass in microseconds (us).
/// This is measured and recorded by the Response Router rather than the
/// Inference Engine. Note, when evaluating the responses, if this
/// value is greater than the stream's measured value, then there was a gap
/// between forward passes. In normal operation, the value of this field should
/// be less than the recorded value on the response stream.
pub time_since_last_forward_pass_us: Option<u64>,
pub request_active_count: u32,
pub request_context_count: u32,
pub request_generation_count: u32,
pub request_scheduled_count: u32,
pub request_max_count: u32,
pub kv_free_cache_blocks: u64,
pub kv_max_cache_blocks: u64,
pub kv_used_cache_blocks: u64,
pub kv_tokens_per_cache_block: u64,
pub runtime_cpu_memory_usage: u64,
pub runtime_gpu_memory_usage: u64,
pub runtime_pinned_memory_usage: u64,
pub iteration_counter: u64,
pub microbatch_id: u64,
pub total_context_tokens: u32,
pub timestamp: String,
}
/// Hand-written serializer for [`StreamingCompletionResponse`].
///
/// Only the `delta` field is written; `logprobs` is not part of the serialized form.
/// The tests in this file pin that shape (a JSON object with a single `delta` key).
impl Serialize for StreamingCompletionResponse {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The length handed to `serialize_struct` must equal the number of
        // `serialize_field` calls that follow. Self-describing text formats such as
        // JSON ignore it, but formats that write the length as a prefix would emit a
        // malformed record if it claimed two fields while only one is written.
        let mut state = serializer.serialize_struct("StreamingCompletionResponse", 1)?;

        // Serialize the `delta` field.
        state.serialize_field("delta", &self.delta)?;

        state.end()
    }
}
/// Hand-written deserializer for [`StreamingCompletionResponse`].
///
/// Decoding is delegated to a private derived mirror struct, whose fields are then
/// moved into the public type.
impl<'de> Deserialize<'de> for StreamingCompletionResponse {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Private mirror of the public struct; its derived impl does the real work.
        #[derive(Deserialize)]
        struct TempResponse {
            delta: Delta,
            logprobs: Option<ChatCompletionLogprobs>,
        }

        TempResponse::deserialize(deserializer).map(|TempResponse { delta, logprobs }| {
            StreamingCompletionResponse { delta, logprobs }
        })
    }
}
/// Two series, `x` and `y`, of the same element type `T`.
///
/// Nothing here enforces that the two vectors have equal length.
/// NOTE(review): presumably `x[i]`/`y[i]` form one scatter point — confirm with the producer.
#[derive(Serialize, Deserialize, Debug)]
pub struct ScatterData<T> {
/// Values for the x series.
pub x: Vec<T>,
/// Values for the y series.
pub y: Vec<T>,
}
/// Timing and token-count record, bounded by the `start` and `complete` wall-clock times.
///
/// NOTE(review): presumably one `Trace` describes one request, and the units of
/// `time_to_first_token` / `token_to_token` are not stated in this file — confirm both.
#[derive(Serialize, Deserialize, Debug)]
pub struct Trace {
pub time_to_first_token: u64,
pub token_to_token: Vec<u64>,
/// Wall-clock time recorded as the start of the trace.
pub start: SystemTime,
/// Wall-clock time recorded as the completion of the trace.
pub complete: SystemTime,
pub initial_tokens: u32,
pub max_tokens: u32,
pub t2ft_iteration_count: u64,
pub t2t_iteration_count: Vec<u64>,
}
/// Parameters of two linear regressions (`t2ft` and `t2tl`, each against initial tokens)
/// together with the r2 value of each fit.
///
/// NOTE(review): `t2ft` presumably abbreviates "time to first token"; the expansion of
/// `t2tl` is not given in this file — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct PerformanceModel {
// linear regression parameters fitting t2ft vs. initial tokens
pub t2ft_intercept: f64,
pub t2ft_slope: f64,
// linear regression parameters fitting t2tl vs. initial tokens
pub t2tl_intercept: f64,
pub t2tl_slope: f64,
// r2 values from the regression
pub t2ft_fit_r2: f64,
pub t2tl_fit_r2: f64,
}
/// Bundle of calibration output: two effective-throughput figures, `max_q`, the fitted
/// [`PerformanceModel`], the [`Trace`]s, and one [`ScatterData`] series each for `t2ft` and `t2tl`.
///
/// NOTE(review): units of `effective_flops` / `effective_memory_bandwidth` and the meaning
/// of `max_q` are not stated in this file — confirm with the calibration tool.
#[derive(Serialize, Deserialize, Debug)]
pub struct CalibrationResults {
pub effective_flops: f64,
pub effective_memory_bandwidth: f64,
pub max_q: u32,
/// Regression fit stored alongside the traces and scatter data.
pub performance_model: PerformanceModel,
/// Traces stored with this result set.
pub traces: Vec<Trace>,
pub t2ft_scatter_data: ScatterData<f64>,
pub t2tl_scatter_data: ScatterData<f64>,
}
/// Bundle of load-generation output: [`Stats`] keyed by a `u64` and a list of [`Trace`]s.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoadgenResults {
/// Stats keyed by a `u64`; per the field name, the key is an iteration number.
pub stats_by_iteration: HashMap<u64, Stats>,
/// Traces stored with this result set.
pub traces: Vec<Trace>,
}
impl CompletionContext { impl CompletionContext {
/// Create a new CompletionContext /// Create a new CompletionContext
pub fn new(prompt: String, system_prompt: Option<String>) -> Self { pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
...@@ -712,7 +525,6 @@ impl From<CompletionContext> for PromptType { ...@@ -712,7 +525,6 @@ impl From<CompletionContext> for PromptType {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json;
use super::*; use super::*;
...@@ -759,360 +571,4 @@ mod tests { ...@@ -759,360 +571,4 @@ mod tests {
panic!("Expected a Completion variant"); panic!("Expected a Completion variant");
} }
} }
// #[test]
// fn test_serialize_with_stats() {
// let response = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // Serialize the response
// let serialized = serde_json::to_string(&response).expect("Failed to serialize");
// // Expected JSON string (simplified)
// let expected = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// assert_eq!(
// serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
// serde_json::from_str::<serde_json::Value>(expected).unwrap()
// );
// }
/// Serializing a response whose delta is entirely empty yields a JSON object with a
/// single `delta` key whose optional members are all `null`.
#[test]
fn test_serialize_without_stats() {
    let empty_delta = Delta {
        is_complete: false,
        finish_reason: None,
        token_ids: None,
        tokens: None,
        text: None,
        sequence_length: None,
        index: None,
        cum_log_probs: None,
        err_msg: None,
        usage: None,
    };
    let response = StreamingCompletionResponse {
        delta: empty_delta,
        logprobs: None,
    };

    // Round-trip through a string so the comparison is done on parsed JSON values.
    let serialized = serde_json::to_string(&response).expect("Failed to serialize");
    let actual = serde_json::from_str::<serde_json::Value>(&serialized).unwrap();

    // The JSON document the serializer is expected to produce.
    let expected = r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#;
    let wanted = serde_json::from_str::<serde_json::Value>(expected).unwrap();

    assert_eq!(actual, wanted);
}
// #[test]
// fn test_deserialize_with_stats() {
// let json_data = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// // Deserialize the JSON string
// let deserialized: StreamingCompletionResponse =
// serde_json::from_str(json_data).expect("Failed to deserialize");
// // Expected response object
// let expected = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // This is wieldy but we can no longer do assert_eq!(deserialized, expected);
// // because the struct no longer has the PartialEq trait
// assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
// assert_eq!(
// deserialized.delta.finish_reason,
// expected.delta.finish_reason
// );
// assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
// assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
// assert_eq!(deserialized.delta.text, expected.delta.text);
// assert_eq!(
// deserialized.delta.sequence_length,
// expected.delta.sequence_length
// );
// assert_eq!(deserialized.delta.index, expected.delta.index);
// assert_eq!(
// deserialized.delta.cum_log_probs,
// expected.delta.cum_log_probs
// );
// assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
// assert_eq!(deserialized.delta.usage, expected.delta.usage);
// assert_eq!(
// deserialized_stats.time_since_last_forward_pass_us,
// expected_stats.time_since_last_forward_pass_us
// );
// assert_eq!(
// deserialized_stats.request_active_count,
// expected_stats.request_active_count
// );
// assert_eq!(
// deserialized_stats.request_context_count,
// expected_stats.request_context_count
// );
// assert_eq!(
// deserialized_stats.request_generation_count,
// expected_stats.request_generation_count
// );
// assert_eq!(
// deserialized_stats.request_scheduled_count,
// expected_stats.request_scheduled_count
// );
// assert_eq!(
// deserialized_stats.request_max_count,
// expected_stats.request_max_count
// );
// assert_eq!(
// deserialized_stats.kv_free_cache_blocks,
// expected_stats.kv_free_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_max_cache_blocks,
// expected_stats.kv_max_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_used_cache_blocks,
// expected_stats.kv_used_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_tokens_per_cache_block,
// expected_stats.kv_tokens_per_cache_block
// );
// assert_eq!(
// deserialized_stats.runtime_cpu_memory_usage,
// expected_stats.runtime_cpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_gpu_memory_usage,
// expected_stats.runtime_gpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_pinned_memory_usage,
// expected_stats.runtime_pinned_memory_usage
// );
// assert_eq!(
// deserialized_stats.iteration_counter,
// expected_stats.iteration_counter
// );
// assert_eq!(
// deserialized_stats.microbatch_id,
// expected_stats.microbatch_id
// );
// assert_eq!(
// deserialized_stats.total_context_tokens,
// expected_stats.total_context_tokens
// );
// assert_eq!(deserialized_stats.timestamp, expected_stats.timestamp);
// }
/// Deserializing a JSON document with an all-`null` delta yields a response whose delta
/// matches an explicitly constructed empty one, field by field.
#[test]
fn test_deserialize_without_stats() {
    let json_data = r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#;

    let deserialized: StreamingCompletionResponse =
        serde_json::from_str(json_data).expect("Failed to deserialize");

    // Reference value the decoded response is compared against.
    let expected = StreamingCompletionResponse {
        delta: Delta {
            is_complete: false,
            finish_reason: None,
            token_ids: None,
            tokens: None,
            text: None,
            sequence_length: None,
            index: None,
            cum_log_probs: None,
            err_msg: None,
            usage: None,
        },
        logprobs: None,
    };

    // The response type has no `PartialEq`, so the delta is compared one field at a time.
    let got = &deserialized.delta;
    let want = &expected.delta;
    assert_eq!(got.is_complete, want.is_complete);
    assert_eq!(got.finish_reason, want.finish_reason);
    assert_eq!(got.token_ids, want.token_ids);
    assert_eq!(got.tokens, want.tokens);
    assert_eq!(got.text, want.text);
    assert_eq!(got.sequence_length, want.sequence_length);
    assert_eq!(got.index, want.index);
    assert_eq!(got.cum_log_probs, want.cum_log_probs);
    assert_eq!(got.err_msg, want.err_msg);
    assert_eq!(got.usage, want.usage);
}
/// Serializing a response with a fully populated delta (and no logprobs) yields a JSON
/// object with a single `delta` key holding those values.
#[test]
fn test_serialize_delta_and_none_stats() {
    let populated_delta = Delta {
        is_complete: true,
        finish_reason: Some(FinishReason::Length),
        token_ids: Some(vec![101, 102, 103]),
        tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
        text: Some("example text".to_string()),
        sequence_length: Some(3),
        index: Some(0),
        cum_log_probs: Some(-0.5),
        err_msg: None,
        usage: None,
    };
    let response = StreamingCompletionResponse {
        delta: populated_delta,
        logprobs: None,
    };

    let serialized = serde_json::to_string(&response).expect("Failed to serialize");
    // Both sides are parsed into `serde_json::Value` so formatting differences do not matter.
    let actual = serde_json::from_str::<serde_json::Value>(&serialized).unwrap();

    // Expected document: only the `delta` object, with no other top-level key.
    let expected_json = r#"{
"delta": {
"is_complete": true,
"finish_reason": "length",
"token_ids": [101, 102, 103],
"tokens": ["token1", "token2"],
"text": "example text",
"sequence_length": 3,
"index": 0,
"cum_log_probs": -0.5,
"err_msg": null,
"usage": null
}
}"#;
    let wanted = serde_json::from_str::<serde_json::Value>(expected_json).unwrap();

    assert_eq!(actual, wanted);
}
} }
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput { pub struct BackendOutput {
/// New token_ids generated from the LLM Engine /// New token_ids generated from the LLM Engine
......
...@@ -13,24 +13,22 @@ ...@@ -13,24 +13,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod chat_completions; use std::fmt::Display;
pub mod completions;
pub mod embeddings;
pub mod models;
pub mod nvext;
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
ops::{Add, Div, Mul, Sub},
};
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider, ContentProvider,
}; };
pub mod chat_completions;
pub mod completions;
pub mod embeddings;
pub mod models;
pub mod nvext;
/// Minimum allowed value for OpenAI's `temperature` sampling option /// Minimum allowed value for OpenAI's `temperature` sampling option
pub const MIN_TEMPERATURE: f32 = 0.0; pub const MIN_TEMPERATURE: f32 = 0.0;
...@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0; ...@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option /// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY); pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON
/// object corresponding to `R`; however, the comments in the SSE stream `: `
/// may correspond to other types of information, such as performance metrics,
/// as represented by other arms of this enum.
///
/// This is part of the common API as both the client and service need to agree
/// on the format of the streaming responses.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingDelta<R> {
/// Represents a response delta from the API
Delta(R),
/// Carries a string payload; per the enum docs, this corresponds to an SSE comment (`: `) line.
Comment(String),
}
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> { pub struct AnnotatedDelta<R> {
pub delta: R, pub delta: R,
...@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T { ...@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
} }
} }
/// Common structure for chat completion responses; the only delta is the type of choices which differs
/// between streaming and non-streaming requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenericCompletionResponse<C>
// where
// C: Serialize + Clone,
{
/// A unique identifier for the chat completion.
pub id: String,
/// A list of chat completion choices. Can be more than one if n is greater than 1.
pub choices: Vec<C>,
/// The Unix timestamp (in seconds) of when the chat completion was created.
pub created: u64,
/// The model used for the chat completion.
pub model: String,
/// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub object: String,
/// Optional usage information, typed as `async_openai::types::CompletionUsage`.
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
///
/// Can be used in conjunction with the seed request parameter to understand when backend changes
/// have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
pub system_fingerprint: Option<String>,
// TODO() - add NvResponseExtention
}
// todo - move to common location // todo - move to common location
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>> fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
where where
...@@ -235,30 +180,6 @@ where ...@@ -235,30 +180,6 @@ where
Ok(Some(value)) Ok(Some(value))
} }
// todo - move to common location
/// Linearly maps `value` from the `src` interval onto the `dst` interval.
///
/// Each interval is a `(start, end)` pair. The result is
/// `dst.0 + ((value - src.0) / (src.1 - src.0)) * (dst.1 - dst.0)`.
///
/// # Errors
///
/// Returns an error when either interval has zero width. The destination interval is
/// checked first.
pub fn scale_value<T>(value: &T, src: &(T, T), dst: &(T, T)) -> Result<T>
where
    T: Copy
        + PartialOrd
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + From<f32>,
{
    let (src_start, src_end) = *src;
    let (dst_start, dst_end) = *dst;

    let dst_width = dst_end - dst_start;
    let src_width = src_end - src_start;

    // A zero-width interval cannot be scaled into or out of.
    let zero = T::from(0.0);
    if dst_width == zero {
        anyhow::bail!("dst range is 0");
    }
    if src_width == zero {
        anyhow::bail!("src range is 0");
    }

    // Position of `value` within `src`, expressed as a fraction of the source width.
    let fraction = (*value - src_start) / src_width;
    Ok(dst_start + (fraction * dst_width))
}
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>: pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
Send + Sync + 'static Send + Sync + 'static
{ {
...@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu ...@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
/// Gets the current prompt token count (Input Sequence Length). /// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>; fn get_isl(&self) -> Option<u32>;
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Values inside the range (bounds included, degenerate ranges too) are passed
    /// through unchanged; values outside it produce the formatted error message.
    #[test]
    fn test_validate_range() {
        let unit = (0.0, 1.0);

        for &accepted in &[0.5, 0.0] {
            assert_eq!(validate_range(Some(accepted), &unit).unwrap(), Some(accepted));
        }
        assert_eq!(validate_range(Some(1.0), &(1.0, 1.0)).unwrap(), Some(1.0));
        assert_eq!(validate_range(Some(1_i32), &(1, 1)).unwrap(), Some(1));

        let rejected_cases = [
            (1.1, "Value 1.1 is out of range [0, 1]"),
            (-0.1, "Value -0.1 is out of range [0, 1]"),
        ];
        for &(rejected, message) in &rejected_cases {
            let failure = validate_range(Some(rejected), &unit).unwrap_err();
            assert_eq!(failure.to_string(), message);
        }
    }

    /// Scaling maps values proportionally between ranges and fails on a zero-width source.
    #[test]
    fn test_scaled_value() {
        let cases = [
            (0.5, (0.0, 1.0), (0.0, 2.0), 1.0),
            (0.0, (0.0, 1.0), (0.0, 2.0), 0.0),
            (-1.0, (-2.0, 2.0), (1.0, 2.0), 1.25),
        ];
        for &(value, src, dst, expected) in &cases {
            assert_eq!(scale_value(&value, &src, &dst).unwrap(), expected);
        }

        assert!(scale_value(&1.0, &(1.0, 1.0), &(0.0, 2.0)).is_err());
    }
}
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
use super::nvext::NvExt; use super::nvext::NvExt;
use super::nvext::NvExtProvider; use super::nvext::NvExtProvider;
use super::OpenAISamplingOptionsProvider; use super::OpenAISamplingOptionsProvider;
use super::OpenAIStopConditionsProvider; use super::OpenAIStopConditionsProvider;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
mod aggregator; mod aggregator;
mod delta; mod delta;
......
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, pin::Pin};
use futures::{Stream, StreamExt};
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
use futures::{Stream, StreamExt};
use std::{collections::HashMap, pin::Pin};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
......
...@@ -13,17 +13,15 @@ ...@@ -13,17 +13,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
mod aggregator; mod aggregator;
mod nvext; mod nvext;
pub use nvext::{NvExt, NvExtProvider};
// pub use delta::DeltaGenerator;
pub use aggregator::DeltaAggregator; pub use aggregator::DeltaAggregator;
pub use nvext::{NvExt, NvExtProvider};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingRequest { pub struct NvCreateEmbeddingRequest {
......
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::pin::Pin;
use futures::{Stream, StreamExt};
use super::NvCreateEmbeddingResponse; use super::NvCreateEmbeddingResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
use futures::{Stream, StreamExt};
use std::pin::Pin;
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment