Unverified Commit e83009a6 authored by Tom O'Brien's avatar Tom O'Brien Committed by GitHub
Browse files

feat: add implementation for embeddings (#1290)

parent 5e9370d3
...@@ -33,6 +33,7 @@ pub async fn run( ...@@ -33,6 +33,7 @@ pub async fn run(
.port(flags.http_port) .port(flags.http_port)
.enable_chat_endpoints(true) .enable_chat_endpoints(true)
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(template) .with_request_template(template)
.build()?; .build()?;
match engine_config { match engine_config {
......
...@@ -77,6 +77,42 @@ class RequestHandler: ...@@ -77,6 +77,42 @@ class RequestHandler:
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
class EmbeddingRequestHandler(RequestHandler):
"""
Request handler for the embedding endpoint
"""
def __init__(self, engine: sglang.Engine, model_name: str):
super().__init__(engine)
self._model_name = model_name
async def generate(self, request):
gen = await self.engine_client.async_encode(prompt=request["input"])
tokens = 0
embeddings = []
for idx, res in enumerate(gen):
embeddings.append(
{
"index": idx,
"object": "embedding",
"embedding": res["embedding"],
}
)
tokens += res["meta_info"]["prompt_tokens"]
out = {
"object": "list",
"model": self._model_name,
"data": embeddings,
"usage": {
"prompt_tokens": tokens,
"total_tokens": tokens,
},
}
yield out
@dynamo_worker(static=False) @dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
...@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service() await component.create_service()
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm( model_type = (
ModelType.Backend, endpoint, config.model_path, config.model_name ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
) )
await register_llm(model_type, endpoint, config.model_path, config.model_name)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked # after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate) await endpoint.serve_endpoint(
RequestHandler(engine_client).generate
if not engine_args.is_embedding
else EmbeddingRequestHandler(
engine_client, model_name=config.model_name or config.model_path
).generate
)
def cmd_line_args(): def cmd_line_args():
...@@ -230,7 +273,6 @@ def cmd_line_args(): ...@@ -230,7 +273,6 @@ def cmd_line_args():
config.node_rank = args.node_rank config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr config.dist_init_addr = args.dist_init_addr
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
return config return config
......
...@@ -129,9 +129,7 @@ impl ModelManager { ...@@ -129,9 +129,7 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs pub fn get_embeddings_engine(
#[allow(dead_code)]
fn get_embeddings_engine(
&self, &self,
model: &str, model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> { ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
......
...@@ -27,7 +27,7 @@ use super::{ ...@@ -27,7 +27,7 @@ use super::{
service_v2, RouteDoc, service_v2, RouteDoc,
}; };
use crate::protocols::openai::embeddings::NvCreateEmbeddingRequest; use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse};
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse, chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
}; };
...@@ -208,10 +208,59 @@ async fn completions( ...@@ -208,10 +208,59 @@ async fn completions(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn embeddings( async fn embeddings(
State(_state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
Json(_request): Json<NvCreateEmbeddingRequest>, Json(request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
unimplemented!("embeddings are not supported yet"); // return a 503 if the service is not ready
check_ready(&state)?;
// todo - extract distributed tracing id and context id from headers
let request_id = uuid::Uuid::new_v4().to_string();
// Embeddings are typically not streamed, so we default to non-streaming
let streaming = false;
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let model = &request.inner.model;
// todo - error handling should be more robust
let engine = state
.manager()
.get_embeddings_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?;
// this will increment the inflight gauge for the model
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Embeddings, streaming);
// setup context
// todo - inherit request_id from distributed trace details
let request = Context::with_id(request, request_id.clone());
// issue the generate call on the engine
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate embeddings"))?;
// Embeddings are typically returned as a single response (non-streaming)
// so we fold the stream into a single response
let response = NvCreateEmbeddingResponse::from_annotated_stream(stream.into())
.await
.map_err(|e| {
tracing::error!(
"Failed to fold embeddings stream for {}: {:?}",
request_id,
e
);
ErrorResponse::internal_server_error("Failed to fold embeddings stream")
})?;
inflight.mark_ok();
Ok(Json(response).into_response())
} }
/// OpenAI Chat Completions Request Handler /// OpenAI Chat Completions Request Handler
......
...@@ -75,7 +75,7 @@ pub struct HttpServiceConfig { ...@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
#[builder(default = "true")] #[builder(default = "true")]
enable_cmpl_endpoints: bool, enable_cmpl_endpoints: bool,
#[builder(default = "false")] #[builder(default = "true")]
enable_embeddings_endpoints: bool, enable_embeddings_endpoints: bool,
#[builder(default = "None")] #[builder(default = "None")]
......
...@@ -16,9 +16,12 @@ ...@@ -16,9 +16,12 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
mod aggregator;
mod nvext; mod nvext;
pub use nvext::{NvExt, NvExtProvider}; pub use nvext::{NvExt, NvExtProvider};
// pub use delta::DeltaGenerator;
pub use aggregator::DeltaAggregator;
use dynamo_runtime::protocols::annotated::AnnotationsProvider; use dynamo_runtime::protocols::annotated::AnnotationsProvider;
...@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse { ...@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse {
} }
} }
/// Implements `NvExtProvider` for `NvCr eateEmbeddingRequest`, /// Implements `NvExtProvider` for `NvCreateEmbeddingRequest`,
/// providing access to NVIDIA-specific extensions. /// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateEmbeddingRequest { impl NvExtProvider for NvCreateEmbeddingRequest {
/// Returns a reference to the optional `NvExt` extension, if available. /// Returns a reference to the optional `NvExt` extension, if available.
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::NvCreateEmbeddingResponse;
use crate::protocols::{
codec::{Message, SseCodecError},
convert_sse_stream, Annotated,
};
use futures::{Stream, StreamExt};
use std::pin::Pin;
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
/// than text generation as embeddings are usually returned as a complete response.
pub struct DeltaAggregator {
/// The accumulated embeddings response.
response: Option<NvCreateEmbeddingResponse>,
/// Optional error message if an error occurs during aggregation.
error: Option<String>,
}
impl Default for DeltaAggregator {
/// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
fn default() -> Self {
Self::new()
}
}
impl DeltaAggregator {
/// Creates a new, empty [`DeltaAggregator`] instance.
pub fn new() -> Self {
Self {
response: None,
error: None,
}
}
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`].
///
/// # Arguments
/// * `stream` - A stream of annotated embedding responses.
///
/// # Returns
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// Attempt to unwrap the delta, capturing any errors.
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
aggregator.error = Some(error);
return aggregator;
}
};
if aggregator.error.is_none() {
if let Some(response) = delta.data {
// For embeddings, we typically expect a single complete response
// or we accumulate data from multiple responses
match &mut aggregator.response {
Some(existing) => {
// Merge embedding data if we have multiple responses
existing.inner.data.extend(response.inner.data);
// Update usage statistics
existing.inner.usage.prompt_tokens +=
response.inner.usage.prompt_tokens;
existing.inner.usage.total_tokens +=
response.inner.usage.total_tokens;
}
None => {
aggregator.response = Some(response);
}
}
}
}
aggregator
})
.await;
// Return early if an error was encountered.
if let Some(error) = aggregator.error {
return Err(error);
}
// Return the aggregated response or an empty response if none was found.
Ok(aggregator
.response
.unwrap_or_else(NvCreateEmbeddingResponse::empty))
}
}
impl NvCreateEmbeddingResponse {
/// Converts an SSE stream into a [`NvCreateEmbeddingResponse`].
///
/// # Arguments
/// * `stream` - A stream of SSE messages containing embedding responses.
///
/// # Returns
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<NvCreateEmbeddingResponse, String> {
let stream = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
NvCreateEmbeddingResponse::from_annotated_stream(stream).await
}
/// Aggregates an annotated stream of embedding responses into a final response.
///
/// # Arguments
/// * `stream` - A stream of annotated embedding responses.
///
/// # Returns
/// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
) -> Result<NvCreateEmbeddingResponse, String> {
DeltaAggregator::apply(stream).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::stream;
fn create_test_embedding_response(
embeddings: Vec<async_openai::types::Embedding>,
prompt_tokens: u32,
total_tokens: u32,
) -> Annotated<NvCreateEmbeddingResponse> {
let response = NvCreateEmbeddingResponse {
inner: async_openai::types::CreateEmbeddingResponse {
object: "list".to_string(),
model: "test-model".to_string(),
data: embeddings,
usage: async_openai::types::EmbeddingUsage {
prompt_tokens,
total_tokens,
},
},
};
Annotated::from_data(response)
}
#[tokio::test]
async fn test_empty_stream() {
let stream = stream::empty();
let result = DeltaAggregator::apply(Box::pin(stream)).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.data.len(), 0);
assert_eq!(response.inner.object, "list");
assert_eq!(response.inner.model, "embedding");
}
#[tokio::test]
async fn test_single_embedding() {
let embedding = async_openai::types::Embedding {
index: 0,
object: "embedding".to_string(),
embedding: vec![0.1, 0.2, 0.3],
};
let annotated = create_test_embedding_response(vec![embedding.clone()], 10, 10);
let stream = stream::iter(vec![annotated]);
let result = DeltaAggregator::apply(Box::pin(stream)).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.data.len(), 1);
assert_eq!(response.inner.data[0].index, 0);
assert_eq!(response.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
assert_eq!(response.inner.usage.prompt_tokens, 10);
assert_eq!(response.inner.usage.total_tokens, 10);
}
#[tokio::test]
async fn test_multiple_embeddings() {
let embedding1 = async_openai::types::Embedding {
index: 0,
object: "embedding".to_string(),
embedding: vec![0.1, 0.2, 0.3],
};
let embedding2 = async_openai::types::Embedding {
index: 1,
object: "embedding".to_string(),
embedding: vec![0.4, 0.5, 0.6],
};
let annotated1 = create_test_embedding_response(vec![embedding1.clone()], 5, 5);
let annotated2 = create_test_embedding_response(vec![embedding2.clone()], 7, 7);
let stream = stream::iter(vec![annotated1, annotated2]);
let result = DeltaAggregator::apply(Box::pin(stream)).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.data.len(), 2);
assert_eq!(response.inner.data[0].index, 0);
assert_eq!(response.inner.data[1].index, 1);
assert_eq!(response.inner.usage.prompt_tokens, 12); // sum of 5 and 7
assert_eq!(response.inner.usage.total_tokens, 12); // sum of 5 and 7
}
#[tokio::test]
async fn test_error_in_stream() {
let error_annotated =
Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());
let stream = stream::iter(vec![error_annotated]);
let result = DeltaAggregator::apply(Box::pin(stream)).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("Test error"));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment