Unverified Commit 9d7c5df5 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: remove dead protocols code and organize imports idiomatically (#1669)

parent 03d976c7
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
//! both publicly via the HTTP API and internally between Dynamo components. //! both publicly via the HTTP API and internally between Dynamo components.
//! //!
use std::pin::Pin;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::pin::Pin;
pub mod codec; pub mod codec;
pub mod common; pub mod common;
...@@ -48,13 +49,6 @@ pub trait ContentProvider { ...@@ -48,13 +49,6 @@ pub trait ContentProvider {
fn content(&self) -> String; fn content(&self) -> String;
} }
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
pub total_tokens: i32,
}
/// Converts of a stream of [codec::Message]s into a stream of [Annotated]s. /// Converts of a stream of [codec::Message]s into a stream of [Annotated]s.
pub fn convert_sse_stream<R>( pub fn convert_sse_stream<R>(
stream: DataStream<Result<codec::Message, codec::SseCodecError>>, stream: DataStream<Result<codec::Message, codec::SseCodecError>>,
......
...@@ -23,10 +23,11 @@ ...@@ -23,10 +23,11 @@
// TODO: Determine if we should use an External EventSource crate. There appear to be several // TODO: Determine if we should use an External EventSource crate. There appear to be several
// potential candidates. // potential candidates.
use std::{io::Cursor, pin::Pin};
use bytes::BytesMut; use bytes::BytesMut;
use futures::Stream; use futures::Stream;
use serde::Deserialize; use serde::Deserialize;
use std::{io::Cursor, pin::Pin};
use tokio_util::codec::{Decoder, FramedRead, LinesCodec}; use tokio_util::codec::{Decoder, FramedRead, LinesCodec};
use super::Annotated; use super::Annotated;
......
This diff is collapsed.
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput { pub struct BackendOutput {
/// New token_ids generated from the LLM Engine /// New token_ids generated from the LLM Engine
......
...@@ -13,24 +13,22 @@ ...@@ -13,24 +13,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod chat_completions; use std::fmt::Display;
pub mod completions;
pub mod embeddings;
pub mod models;
pub mod nvext;
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
ops::{Add, Div, Mul, Sub},
};
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider, ContentProvider,
}; };
pub mod chat_completions;
pub mod completions;
pub mod embeddings;
pub mod models;
pub mod nvext;
/// Minimum allowed value for OpenAI's `temperature` sampling option /// Minimum allowed value for OpenAI's `temperature` sampling option
pub const MIN_TEMPERATURE: f32 = 0.0; pub const MIN_TEMPERATURE: f32 = 0.0;
...@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0; ...@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option /// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY); pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON
/// object corresponding to `R`; however, the comments in the SSE stream `: `
/// may correspond to other types of information, such as performance metrics,
/// as represented by other arms of this enum.
///
/// This is part of the common API as both the client and service need to agree
/// on the format of the streaming responses.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingDelta<R> {
/// Represents a response delta from the API
Delta(R),
Comment(String),
}
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> { pub struct AnnotatedDelta<R> {
pub delta: R, pub delta: R,
...@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T { ...@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
} }
} }
/// Common structure for chat completion responses; the only delta is the type of choices which differs
/// between streaming and non-streaming requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenericCompletionResponse<C>
// where
// C: Serialize + Clone,
{
/// A unique identifier for the chat completion.
pub id: String,
/// A list of chat completion choices. Can be more than one if n is greater than 1.
pub choices: Vec<C>,
/// The Unix timestamp (in seconds) of when the chat completion was created.
pub created: u64,
/// The model used for the chat completion.
pub model: String,
/// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub object: String,
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
///
/// Can be used in conjunction with the seed request parameter to understand when backend changes
/// have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
pub system_fingerprint: Option<String>,
// TODO() - add NvResponseExtention
}
// todo - move to common location // todo - move to common location
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>> fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
where where
...@@ -235,30 +180,6 @@ where ...@@ -235,30 +180,6 @@ where
Ok(Some(value)) Ok(Some(value))
} }
// todo - move to common location
/// scale value in `src` range to `dst` range
pub fn scale_value<T>(value: &T, src: &(T, T), dst: &(T, T)) -> Result<T>
where
T: Copy
+ PartialOrd
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ From<f32>,
{
let dst_range = dst.1 - dst.0;
let src_range = src.1 - src.0;
if dst_range == T::from(0.0) {
anyhow::bail!("dst range is 0");
}
if src_range == T::from(0.0) {
anyhow::bail!("src range is 0");
}
let value_scaled = (*value - src.0) / src_range;
Ok(dst.0 + (value_scaled * dst_range))
}
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>: pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
Send + Sync + 'static Send + Sync + 'static
{ {
...@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu ...@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
/// Gets the current prompt token count (Input Sequence Length). /// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>; fn get_isl(&self) -> Option<u32>;
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_range() {
assert_eq!(validate_range(Some(0.5), &(0.0, 1.0)).unwrap(), Some(0.5));
assert_eq!(validate_range(Some(0.0), &(0.0, 1.0)).unwrap(), Some(0.0));
assert_eq!(validate_range(Some(1.0), &(1.0, 1.0)).unwrap(), Some(1.0));
assert_eq!(validate_range(Some(1_i32), &(1, 1)).unwrap(), Some(1));
assert_eq!(
validate_range(Some(1.1), &(0.0, 1.0))
.unwrap_err()
.to_string(),
"Value 1.1 is out of range [0, 1]"
);
assert_eq!(
validate_range(Some(-0.1), &(0.0, 1.0))
.unwrap_err()
.to_string(),
"Value -0.1 is out of range [0, 1]"
);
}
#[test]
fn test_scaled_value() {
assert_eq!(scale_value(&0.5, &(0.0, 1.0), &(0.0, 2.0)).unwrap(), 1.0);
assert_eq!(scale_value(&0.0, &(0.0, 1.0), &(0.0, 2.0)).unwrap(), 0.0);
assert_eq!(scale_value(&-1.0, &(-2.0, 2.0), &(1.0, 2.0)).unwrap(), 1.25);
assert!(scale_value(&1.0, &(1.0, 1.0), &(0.0, 2.0)).is_err());
}
}
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
use super::nvext::NvExt; use super::nvext::NvExt;
use super::nvext::NvExtProvider; use super::nvext::NvExtProvider;
use super::OpenAISamplingOptionsProvider; use super::OpenAISamplingOptionsProvider;
use super::OpenAIStopConditionsProvider; use super::OpenAIStopConditionsProvider;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
mod aggregator; mod aggregator;
mod delta; mod delta;
......
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, pin::Pin};
use futures::{Stream, StreamExt};
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
use futures::{Stream, StreamExt};
use std::{collections::HashMap, pin::Pin};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
......
...@@ -13,17 +13,15 @@ ...@@ -13,17 +13,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
mod aggregator; mod aggregator;
mod nvext; mod nvext;
pub use nvext::{NvExt, NvExtProvider};
// pub use delta::DeltaGenerator;
pub use aggregator::DeltaAggregator; pub use aggregator::DeltaAggregator;
pub use nvext::{NvExt, NvExtProvider};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingRequest { pub struct NvCreateEmbeddingRequest {
......
...@@ -13,15 +13,16 @@ ...@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::pin::Pin;
use futures::{Stream, StreamExt};
use super::NvCreateEmbeddingResponse; use super::NvCreateEmbeddingResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
use futures::{Stream, StreamExt};
use std::pin::Pin;
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment