"vscode:/vscode.git/clone" did not exist on "c82f477dcc4b1019f3e9b1ab7366169b58a08030"
Unverified Commit 73fdfb8a authored by Tom O'Brien's avatar Tom O'Brien Committed by GitHub
Browse files

feat: Add OpenAI Embeddings interface in rust lib (#1110)

Implements OpenAI embeddings (interface only).

- Adds ModelType::Embedding
- Adds OpenAI embedding request/response structs
- Adds support for embedding model discovery
parent ac82bcf3
...@@ -66,8 +66,8 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -66,8 +66,8 @@ async fn app(runtime: Runtime) -> Result<()> {
// TODO: A single watcher already watches all model types and does the right thing. // TODO: A single watcher already watches all model types and does the right thing.
// The paths need change here and in llmctl to not include the model_type // The paths need change here and in llmctl to not include the model_type
// Create watchers for `Chat` and `Completion` model types // Create watchers for `Chat`, `Completion`, and `Embedding` model types
for model_type in [ModelType::Chat, ModelType::Completion] { for model_type in [ModelType::Chat, ModelType::Completion, ModelType::Embedding] {
let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str()); let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str());
let watch_obj = Arc::new( let watch_obj = Arc::new(
......
...@@ -152,6 +152,11 @@ pub async fn prepare_engine( ...@@ -152,6 +152,11 @@ pub async fn prepare_engine(
) )
*/ */
} }
ModelType::Embedding => {
anyhow::bail!(
"text and batch input only accept remote Chat models, not Embedding"
);
}
}; };
// The service_name isn't used for text chat outside of logs, // The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration. // so use the path. That avoids having to listen on etcd for model registration.
......
...@@ -92,6 +92,12 @@ define_type_subcommands!( ...@@ -92,6 +92,12 @@ define_type_subcommands!(
"Add a completion model" "Add a completion model"
), ),
// Add new model types here: // Add new model types here:
(
Embedding,
"embedding",
["embeddings", "embedding-model"],
"Add an embedding model"
)
); );
#[derive(Parser)] #[derive(Parser)]
......
...@@ -113,6 +113,7 @@ fn register_llm<'p>( ...@@ -113,6 +113,7 @@ fn register_llm<'p>(
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelType::Chat => llm_rs::model_type::ModelType::Chat,
ModelType::Completion => llm_rs::model_type::ModelType::Completion, ModelType::Completion => llm_rs::model_type::ModelType::Completion,
ModelType::Backend => llm_rs::model_type::ModelType::Backend, ModelType::Backend => llm_rs::model_type::ModelType::Backend,
ModelType::Embedding => llm_rs::model_type::ModelType::Embedding,
}; };
let inner_path = model_path.to_string(); let inner_path = model_path.to_string();
...@@ -192,6 +193,7 @@ enum ModelType { ...@@ -192,6 +193,7 @@ enum ModelType {
Chat = 1, Chat = 1,
Completion = 2, Completion = 2,
Backend = 3, Backend = 3,
Embedding = 4,
} }
#[pymethods] #[pymethods]
......
...@@ -38,6 +38,7 @@ use dynamo_runtime::protocols::annotated::Annotated; ...@@ -38,6 +38,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_llm::protocols::openai::{ use dynamo_llm::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionRequest, CompletionResponse}, completions::{prompt_to_string, CompletionRequest, CompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine}; use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
...@@ -600,3 +601,19 @@ fn is_gemma3(s: &str) -> bool { ...@@ -600,3 +601,19 @@ fn is_gemma3(s: &str) -> bool {
fn is_llama4(s: &str) -> bool { fn is_llama4(s: &str) -> bool {
s.to_lowercase().contains("llama-4") s.to_lowercase().contains("llama-4")
} }
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
_request: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
unimplemented!()
}
}
...@@ -32,6 +32,8 @@ use crate::protocols::openai::{ ...@@ -32,6 +32,8 @@ use crate::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionRequest, CompletionResponse}, completions::{prompt_to_string, CompletionRequest, CompletionResponse},
}; };
use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
// //
// The engines are each in their own crate under `lib/engines` // The engines are each in their own crate under `lib/engines`
...@@ -147,8 +149,19 @@ pub trait StreamingEngine: Send + Sync { ...@@ -147,8 +149,19 @@ pub trait StreamingEngine: Send + Sync {
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error>; ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error>;
} }
/// Trait that allows handling embedding requests
#[async_trait]
pub trait EmbeddingEngine: Send + Sync {
async fn handle_embedding(
&self,
req: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error>;
}
pub fn make_engine_full() -> Arc<dyn StreamingEngine> { pub fn make_engine_full() -> Arc<dyn StreamingEngine> {
Arc::new(EngineDispatcher::new(EchoEngineFull {})) let engine = EchoEngineFull {};
let data = EngineDispatcher::new(engine);
Arc::new(data)
} }
#[async_trait] #[async_trait]
...@@ -233,6 +246,22 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -233,6 +246,22 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
} }
} }
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> for EchoEngineFull
{
async fn generate(
&self,
_incoming_request: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
unimplemented!()
}
}
#[async_trait] #[async_trait]
impl<E> StreamingEngine for EngineDispatcher<E> impl<E> StreamingEngine for EngineDispatcher<E>
where where
...@@ -241,6 +270,10 @@ where ...@@ -241,6 +270,10 @@ where
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error, Error,
> + AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> + Send > + Send
+ Sync, + Sync,
{ {
...@@ -259,6 +292,48 @@ where ...@@ -259,6 +292,48 @@ where
} }
} }
#[async_trait]
impl<E> EmbeddingEngine for EngineDispatcher<E>
where
E: AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> + Send
+ Sync,
{
async fn handle_embedding(
&self,
req: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
self.inner.generate(req).await
}
}
pub struct EmbeddingEngineAdapter(Arc<dyn EmbeddingEngine>);
impl EmbeddingEngineAdapter {
pub fn new(engine: Arc<dyn EmbeddingEngine>) -> Self {
EmbeddingEngineAdapter(engine)
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> for EmbeddingEngineAdapter
{
async fn generate(
&self,
req: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
self.0.handle_embedding(req).await
}
}
pub struct StreamingEngineAdapter(Arc<dyn StreamingEngine>); pub struct StreamingEngineAdapter(Arc<dyn StreamingEngine>);
impl StreamingEngineAdapter { impl StreamingEngineAdapter {
......
...@@ -47,7 +47,7 @@ pub use metrics::Metrics; ...@@ -47,7 +47,7 @@ pub use metrics::Metrics;
use crate::types::openai::{ use crate::types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
}; };
use std::{ use std::{
collections::HashMap, collections::HashMap,
...@@ -116,6 +116,15 @@ impl ModelManager { ...@@ -116,6 +116,15 @@ impl ModelManager {
clients.add(model, engine) clients.add(model, engine)
} }
pub fn add_embeddings_model(
&self,
model: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ServiceHttpError> {
let mut clients = self.state.embeddings_engines.lock().unwrap();
clients.add(model, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ServiceHttpError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ServiceHttpError> {
let mut clients = self.state.completion_engines.lock().unwrap(); let mut clients = self.state.completion_engines.lock().unwrap();
clients.remove(model) clients.remove(model)
...@@ -126,6 +135,11 @@ impl ModelManager { ...@@ -126,6 +135,11 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ServiceHttpError> {
let mut clients = self.state.embeddings_engines.lock().unwrap();
clients.remove(model)
}
/// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
pub fn metrics(&self) -> Arc<Metrics> { pub fn metrics(&self) -> Arc<Metrics> {
self.state.metrics.clone() self.state.metrics.clone()
...@@ -191,6 +205,7 @@ impl<E> ModelEngines<E> { ...@@ -191,6 +205,7 @@ impl<E> ModelEngines<E> {
pub struct DeploymentState { pub struct DeploymentState {
completion_engines: Arc<Mutex<ModelEngines<OpenAICompletionsStreamingEngine>>>, completion_engines: Arc<Mutex<ModelEngines<OpenAICompletionsStreamingEngine>>>,
chat_completion_engines: Arc<Mutex<ModelEngines<OpenAIChatCompletionsStreamingEngine>>>, chat_completion_engines: Arc<Mutex<ModelEngines<OpenAIChatCompletionsStreamingEngine>>>,
embeddings_engines: Arc<Mutex<ModelEngines<OpenAIEmbeddingsStreamingEngine>>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
sse_keep_alive: Option<Duration>, sse_keep_alive: Option<Duration>,
} }
...@@ -200,11 +215,26 @@ impl DeploymentState { ...@@ -200,11 +215,26 @@ impl DeploymentState {
Self { Self {
completion_engines: Arc::new(Mutex::new(ModelEngines::default())), completion_engines: Arc::new(Mutex::new(ModelEngines::default())),
chat_completion_engines: Arc::new(Mutex::new(ModelEngines::default())), chat_completion_engines: Arc::new(Mutex::new(ModelEngines::default())),
embeddings_engines: Arc::new(Mutex::new(ModelEngines::default())),
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
sse_keep_alive: None, sse_keep_alive: None,
} }
} }
// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
#[allow(dead_code)]
fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ServiceHttpError> {
self.embeddings_engines
.lock()
.unwrap()
.get(model)
.cloned()
.ok_or(ServiceHttpError::ModelNotFound(model.to_string()))
}
fn get_completions_engine( fn get_completions_engine(
&self, &self,
model: &str, model: &str,
......
...@@ -37,6 +37,7 @@ use crate::{ ...@@ -37,6 +37,7 @@ use crate::{
protocols::openai::chat_completions::{ protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
use tracing; use tracing;
...@@ -240,6 +241,7 @@ impl ModelWatcher { ...@@ -240,6 +241,7 @@ impl ModelWatcher {
// Ignore the errors because model could be either type // Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(model_name); let _ = self.manager.remove_chat_completions_model(model_name);
let _ = self.manager.remove_completions_model(model_name); let _ = self.manager.remove_completions_model(model_name);
let _ = self.manager.remove_embeddings_model(model_name);
Ok(model_name) Ok(model_name)
} }
...@@ -376,6 +378,16 @@ impl ModelWatcher { ...@@ -376,6 +378,16 @@ impl ModelWatcher {
self.manager self.manager
.add_completions_model(&model_entry.name, engine)?; .add_completions_model(&model_entry.name, engine)?;
} }
ModelType::Embedding => {
let push_router = PushRouter::<
NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>,
>::from_client(client, Default::default())
.await?;
let engine = Arc::new(push_router);
self.manager
.add_embeddings_model(&model_entry.name, engine)?;
}
} }
Ok(()) Ok(())
......
...@@ -60,6 +60,9 @@ pub enum Endpoint { ...@@ -60,6 +60,9 @@ pub enum Endpoint {
/// OAI Chat Completions /// OAI Chat Completions
ChatCompletions, ChatCompletions,
/// OAI Embeddings
Embeddings,
} }
/// Metrics for the HTTP service /// Metrics for the HTTP service
...@@ -276,6 +279,7 @@ impl std::fmt::Display for Endpoint { ...@@ -276,6 +279,7 @@ impl std::fmt::Display for Endpoint {
match self { match self {
Endpoint::Completions => write!(f, "completions"), Endpoint::Completions => write!(f, "completions"),
Endpoint::ChatCompletions => write!(f, "chat_completions"), Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"),
} }
} }
} }
...@@ -285,6 +289,7 @@ impl Endpoint { ...@@ -285,6 +289,7 @@ impl Endpoint {
match self { match self {
Endpoint::Completions => "completions", Endpoint::Completions => "completions",
Endpoint::ChatCompletions => "chat_completions", Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings",
} }
} }
} }
......
...@@ -40,6 +40,7 @@ use super::{ ...@@ -40,6 +40,7 @@ use super::{
RouteDoc, RouteDoc,
}; };
use crate::protocols::openai::embeddings::NvCreateEmbeddingRequest;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse, chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
}; };
...@@ -210,6 +211,14 @@ async fn completions( ...@@ -210,6 +211,14 @@ async fn completions(
} }
} }
#[tracing::instrument(skip_all)]
async fn embeddings(
State(_state): State<Arc<DeploymentState>>,
Json(_request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
unimplemented!("embeddings are not supported yet");
}
/// OpenAI Chat Completions Request Handler /// OpenAI Chat Completions Request Handler
/// ///
/// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source" /// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source"
...@@ -391,6 +400,7 @@ async fn list_models_openai( ...@@ -391,6 +400,7 @@ async fn list_models_openai(
.engines .engines
.keys() .keys()
.chain(state.completion_engines.lock().unwrap().engines.keys()) .chain(state.completion_engines.lock().unwrap().engines.keys())
.chain(state.embeddings_engines.lock().unwrap().engines.keys())
.cloned() .cloned()
.collect(); .collect();
...@@ -538,6 +548,20 @@ pub fn chat_completions_router( ...@@ -538,6 +548,20 @@ pub fn chat_completions_router(
(vec![doc], router) (vec![doc], router)
} }
/// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint
/// If not path is provided, the default path is `/v1/embeddings`
pub fn embeddings_router(
state: Arc<DeploymentState>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/embeddings".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(embeddings))
.with_state(state);
(vec![doc], router)
}
/// List Models /// List Models
pub fn list_models_router( pub fn list_models_router(
state: Arc<DeploymentState>, state: Arc<DeploymentState>,
......
...@@ -48,6 +48,9 @@ pub struct HttpServiceConfig { ...@@ -48,6 +48,9 @@ pub struct HttpServiceConfig {
#[builder(default = "true")] #[builder(default = "true")]
enable_cmpl_endpoints: bool, enable_cmpl_endpoints: bool,
#[builder(default = "false")]
enable_embeddings_endpoints: bool,
#[builder(default = "None")] #[builder(default = "None")]
request_template: Option<RequestTemplate>, request_template: Option<RequestTemplate>,
} }
...@@ -93,7 +96,7 @@ impl HttpService { ...@@ -93,7 +96,7 @@ impl HttpService {
impl HttpServiceConfigBuilder { impl HttpServiceConfigBuilder {
pub fn build(self) -> Result<HttpService, anyhow::Error> { pub fn build(self) -> Result<HttpService, anyhow::Error> {
let config = self.build_internal()?; let config: HttpServiceConfig = self.build_internal()?;
let model_manager = ModelManager::new(); let model_manager = ModelManager::new();
...@@ -125,6 +128,13 @@ impl HttpServiceConfigBuilder { ...@@ -125,6 +128,13 @@ impl HttpServiceConfigBuilder {
)); ));
} }
if config.enable_embeddings_endpoints {
routes.push(super::openai::embeddings_router(
model_manager.state(),
None,
));
}
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) { // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
// router = router.merge(route); // router = router.merge(route);
// all_docs.extend(route_docs); // all_docs.extend(route_docs);
......
...@@ -22,6 +22,8 @@ pub enum ModelType { ...@@ -22,6 +22,8 @@ pub enum ModelType {
Chat, Chat,
/// Older completions API /// Older completions API
Completion, Completion,
/// Embeddings API
Embedding,
// Pre-processed requests // Pre-processed requests
Backend, Backend,
} }
...@@ -31,11 +33,12 @@ impl ModelType { ...@@ -31,11 +33,12 @@ impl ModelType {
match self { match self {
Self::Chat => "chat", Self::Chat => "chat",
Self::Completion => "completion", Self::Completion => "completion",
Self::Embedding => "embedding",
Self::Backend => "backend", Self::Backend => "backend",
} }
} }
pub fn all() -> Vec<Self> { pub fn all() -> Vec<Self> {
vec![Self::Chat, Self::Completion, Self::Backend] vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend]
} }
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
pub mod chat_completions; pub mod chat_completions;
pub mod completions; pub mod completions;
pub mod embeddings;
pub mod models; pub mod models;
pub mod nvext; pub mod nvext;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use validator::Validate;
mod nvext;
pub use nvext::{NvExt, NvExtProvider};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingRequest {
#[serde(flatten)]
pub inner: async_openai::types::CreateEmbeddingRequest,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI unary chat completion response, embedded
/// using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateEmbeddingResponse {
#[serde(flatten)]
pub inner: async_openai::types::CreateEmbeddingResponse,
}
impl NvCreateEmbeddingResponse {
pub fn empty() -> Self {
Self {
inner: async_openai::types::CreateEmbeddingResponse {
object: "list".to_string(),
model: "embedding".to_string(),
data: vec![],
usage: async_openai::types::EmbeddingUsage {
prompt_tokens: 0,
total_tokens: 0,
},
},
}
}
}
/// Implements `NvExtProvider` for `NvCr eateEmbeddingRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateEmbeddingRequest {
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
/// Implements `AnnotationsProvider` for `NvCreateEmbeddingRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateEmbeddingRequest {
/// Retrieves the list of annotations from `NvExt`, if present.
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.clone())
}
/// Checks whether a specific annotation exists in the request.
///
/// # Arguments
/// * `annotation` - A string slice representing the annotation to check.
///
/// # Returns
/// `true` if the annotation exists, `false` otherwise.
fn has_annotation(&self, annotation: &str) -> bool {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.as_ref())
.map(|annotations| annotations.contains(&annotation.to_string()))
.unwrap_or(false)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::{Validate, ValidationError};
pub trait NvExtProvider {
fn nvext(&self) -> Option<&NvExt>;
}
/// NVIDIA LLM extensions to the OpenAI API
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt {
/// Annotations
/// User requests triggers which result in the request issue back out-of-band information in the SSE
/// stream using the `event:` field.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub annotations: Option<Vec<String>>,
}
impl Default for NvExt {
fn default() -> Self {
NvExt::builder().build().unwrap()
}
}
impl NvExt {
pub fn builder() -> NvExtBuilder {
NvExtBuilder::default()
}
}
fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(())
}
impl NvExtBuilder {
pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
self.annotations
.get_or_insert_with(|| Some(vec![]))
.as_mut()
.expect("stop should always be Some(Vec)")
.push(annotation.into());
self
}
}
#[cfg(test)]
mod tests {
use super::*;
// Test default builder configuration
#[test]
fn test_nv_ext_builder_default() {
let nv_ext = NvExt::builder().build().unwrap();
assert_eq!(nv_ext.annotations, None);
}
}
...@@ -52,4 +52,20 @@ pub mod openai { ...@@ -52,4 +52,20 @@ pub mod openai {
Annotated<NvCreateChatCompletionStreamResponse>, Annotated<NvCreateChatCompletionStreamResponse>,
>; >;
} }
pub mod embeddings {
use super::*;
pub use protocols::openai::embeddings::{
NvCreateEmbeddingRequest, NvCreateEmbeddingResponse,
};
/// A [`UnaryEngine`] implementation for the OpenAI Embeddings API
pub type OpenAIEmbeddingsUnaryEngine =
UnaryEngine<NvCreateEmbeddingRequest, NvCreateEmbeddingResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Embeddings API
pub type OpenAIEmbeddingsStreamingEngine =
ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>;
}
} }
...@@ -138,6 +138,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu ...@@ -138,6 +138,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
let endpoint = match endpoint { let endpoint = match endpoint {
Endpoint::Completions => 0, Endpoint::Completions => 0,
Endpoint::ChatCompletions => 1, Endpoint::ChatCompletions => 1,
Endpoint::Embeddings => todo!(),
}; };
let request_type = match request_type { let request_type = match request_type {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment