Unverified Commit 47e7fde7 authored by Tom O'Brien's avatar Tom O'Brien Committed by GitHub
Browse files

feat: Implement frontend tokenization for embedding requests (#1494)

parent 36f03d40
...@@ -13,6 +13,7 @@ from typing import Optional ...@@ -13,6 +13,7 @@ from typing import Optional
import sglang import sglang
import uvloop import uvloop
from sglang.srt.entrypoints.engine import EmbeddingReqInput
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from dynamo.llm import ModelType, register_llm from dynamo.llm import ModelType, register_llm
...@@ -126,38 +127,21 @@ class RequestHandler: ...@@ -126,38 +127,21 @@ class RequestHandler:
yield out yield out
async def encode(self, request):
obj = EmbeddingReqInput(input_ids=request["token_ids"])
generator = self.engine_client.tokenizer_manager.generate_request(obj, None)
engine_results = await anext(generator)
class EmbeddingRequestHandler(RequestHandler):
"""
Request handler for the embedding endpoint
"""
def __init__(self, engine: sglang.Engine, model_name: str):
super().__init__(engine)
self._model_name = model_name
async def generate(self, request):
gen = await self.engine_client.async_encode(prompt=request["input"])
tokens = 0 tokens = 0
embeddings = [] embeddings = []
for idx, res in enumerate(gen): for result in engine_results:
embeddings.append( embeddings.append(result["embedding"])
{ tokens += result["meta_info"]["prompt_tokens"]
"index": idx,
"object": "embedding",
"embedding": res["embedding"],
}
)
tokens += res["meta_info"]["prompt_tokens"]
out = { out = {
"object": "list", "embeddings": embeddings,
"model": self._model_name, "prompt_tokens": tokens,
"data": embeddings, "total_tokens": tokens,
"usage": {
"prompt_tokens": tokens,
"total_tokens": tokens,
},
} }
yield out yield out
...@@ -222,13 +206,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -222,13 +206,11 @@ async def init(runtime: DistributedRuntime, config: Config):
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked # after the lease is revoked
await endpoint.serve_endpoint( handler = RequestHandler(engine_client)
RequestHandler(engine_client).generate if engine_args.is_embedding:
if not engine_args.is_embedding await endpoint.serve_endpoint(handler.encode)
else EmbeddingRequestHandler( else:
engine_client, model_name=config.model_name or config.model_path await endpoint.serve_endpoint(handler.generate)
).generate
)
def cmd_line_args(): def cmd_line_args():
......
...@@ -44,7 +44,11 @@ use dynamo_runtime::{ ...@@ -44,7 +44,11 @@ use dynamo_runtime::{
use crate::protocols::{ use crate::protocols::{
common::{ common::{
llm_backend::{BackendOutput, FinishReason, LLMEngineOutput, PreprocessedRequest}, llm_backend::{
BackendOutput, EmbeddingsEngineOutput, FinishReason, LLMEngineOutput,
PreprocessedRequest,
},
preprocessor::PreprocessedEmbeddingRequest,
StopConditions, StopConditions,
}, },
TokenIdType, TokenIdType,
...@@ -233,6 +237,36 @@ impl ...@@ -233,6 +237,36 @@ impl
} }
} }
#[async_trait]
impl
Operator<
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
> for Backend
{
async fn generate(
&self,
request: SingleIn<PreprocessedEmbeddingRequest>,
next: ServerStreamingEngine<
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>,
) -> Result<ManyOut<Annotated<EmbeddingsEngineOutput>>> {
// For embeddings, we mostly pass through since no detokenization is needed
// But we could add validation, logging, or other post-processing here
let response_stream = next.generate(request).await?;
// Could add embedding-specific post-processing here:
// - Validation of embedding dimensions
// - Normalization if requested
// - Usage statistics validation
Ok(response_stream)
}
}
// todo - add visible stop conditions // todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>, // visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>, // visible_stop_sequences: Vec<String>,
......
...@@ -20,8 +20,8 @@ use crate::{ ...@@ -20,8 +20,8 @@ use crate::{
backend::Backend, backend::Backend,
kv_router::{KvPushRouter, KvRouterConfig}, kv_router::{KvPushRouter, KvRouterConfig},
model_type::ModelType, model_type::ModelType,
preprocessor::{OpenAIPreprocessor, PreprocessedRequest}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
protocols::openai::chat_completions::{ protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
...@@ -300,14 +300,42 @@ impl ModelWatcher { ...@@ -300,14 +300,42 @@ impl ModelWatcher {
.add_completions_model(&model_entry.name, engine)?; .add_completions_model(&model_entry.name, engine)?;
} }
ModelType::Embedding => { ModelType::Embedding => {
let push_router = PushRouter::< let Some(mut card) = card else {
NvCreateEmbeddingRequest, anyhow::bail!("Missing model deployment card for embedding model");
Annotated<NvCreateEmbeddingResponse>, };
>::from_client(client, Default::default())
// Download tokenizer files to local disk
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
// Create preprocessing pipeline similar to Backend
let frontend = SegmentSource::<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client(client, self.router_mode)
.await?; .await?;
let engine = Arc::new(push_router);
// Note: Embeddings don't need KV routing complexity
let service_backend = ServiceBackend::from_engine(Arc::new(router));
// Link the pipeline: frontend -> preprocessor -> backend -> service_backend -> backend -> preprocessor -> frontend
let embedding_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(service_backend)?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager self.manager
.add_embeddings_model(&model_entry.name, engine)?; .add_embeddings_model(&model_entry.name, embedding_engine)?;
} }
} }
......
...@@ -27,6 +27,7 @@ pub mod prompt; ...@@ -27,6 +27,7 @@ pub mod prompt;
pub mod tools; pub mod tools;
use anyhow::Result; use anyhow::Result;
use async_openai::types::EncodingFormat;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter; use prompt::OAIPromptFormatter;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
...@@ -48,6 +49,7 @@ use crate::protocols::{ ...@@ -48,6 +49,7 @@ use crate::protocols::{
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
nvext::NvExtProvider, nvext::NvExtProvider,
DeltaGeneratorExt, DeltaGeneratorExt,
}, },
...@@ -57,6 +59,9 @@ use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer}; ...@@ -57,6 +59,9 @@ use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput}; use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest}; pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
pub use crate::protocols::common::preprocessor::PreprocessedEmbeddingRequest;
use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt"; pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids"; pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
...@@ -265,6 +270,68 @@ impl OpenAIPreprocessor { ...@@ -265,6 +270,68 @@ impl OpenAIPreprocessor {
Ok((builder.build()?, annotations)) Ok((builder.build()?, annotations))
} }
/// Preprocess an embedding request, handling both text and token ID inputs.
///
/// For text inputs, tokenizes the text using the configured tokenizer.
/// For token ID inputs, uses the provided token IDs directly and skips tokenization.
///
/// Returns both the preprocessed request and a hashmap of annotations.
pub async fn preprocess_embedding_request(
&self,
request: &NvCreateEmbeddingRequest,
) -> Result<(PreprocessedEmbeddingRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new();
let mut builder = PreprocessedEmbeddingRequest::builder();
let all_token_ids = match &request.inner.input {
async_openai::types::EmbeddingInput::String(s) => {
let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(s))?;
vec![encoding.token_ids]
}
async_openai::types::EmbeddingInput::StringArray(arr) => {
let input_strs: Vec<String> = arr.to_vec();
let encodings = tokio::task::spawn_blocking({
let tokenizer = self.tokenizer.clone();
let strs = input_strs.clone();
move || {
tokenizer.encode_batch(&strs.iter().map(|s| s.as_str()).collect::<Vec<_>>())
}
})
.await??;
let token_arrays: Vec<Vec<u32>> = encodings
.into_iter()
.map(|encoding| encoding.token_ids)
.collect();
token_arrays
}
async_openai::types::EmbeddingInput::IntegerArray(token_ids) => vec![token_ids.clone()],
async_openai::types::EmbeddingInput::ArrayOfIntegerArray(token_arrays) => {
token_arrays.clone()
}
};
// Handle annotations
if request.has_annotation(ANNOTATION_TOKEN_IDS) {
annotations.insert(
ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&all_token_ids)?,
);
}
builder.token_ids(all_token_ids);
builder.model(request.inner.model.clone());
builder.encoding_format(request.inner.encoding_format.as_ref().map(|f| match f {
EncodingFormat::Float => "float".to_string(),
EncodingFormat::Base64 => "base64".to_string(),
}));
builder.dimensions(request.inner.dimensions);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
Ok((builder.build()?, annotations))
}
pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>( pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
stream: ManyOut<Annotated<BackendOutput>>, stream: ManyOut<Annotated<BackendOutput>>,
generator: Box<dyn DeltaGeneratorExt<Resp>>, generator: Box<dyn DeltaGeneratorExt<Resp>>,
...@@ -367,6 +434,46 @@ impl OpenAIPreprocessor { ...@@ -367,6 +434,46 @@ impl OpenAIPreprocessor {
ResponseStream::new(Box::pin(stream), context) ResponseStream::new(Box::pin(stream), context)
} }
/// Transform engine embedding output stream to OpenAI embedding response stream
pub fn transform_embedding_postprocessor_stream(
stream: ManyOut<Annotated<EmbeddingsEngineOutput>>,
original_request: NvCreateEmbeddingRequest,
) -> ManyOut<Annotated<NvCreateEmbeddingResponse>> {
let context = stream.context();
let transformed_stream = stream.map(move |output| {
output.map_data(|engine_output| {
// Convert engine output to OpenAI response format
let embeddings: Vec<async_openai::types::Embedding> = engine_output
.embeddings
.into_iter()
.enumerate()
.map(|(index, embedding)| async_openai::types::Embedding {
index: index as u32,
object: "embedding".to_string(),
embedding: embedding.into_iter().map(|f| f as f32).collect(),
})
.collect();
let response = NvCreateEmbeddingResponse {
inner: async_openai::types::CreateEmbeddingResponse {
object: "list".to_string(),
model: original_request.inner.model.clone(),
data: embeddings,
usage: async_openai::types::EmbeddingUsage {
prompt_tokens: engine_output.prompt_tokens,
total_tokens: engine_output.total_tokens,
},
},
};
Ok(response)
})
});
ResponseStream::new(Box::pin(transformed_stream), context)
}
} }
// for pals, we do not want to add the generation prompt to the formatted prompt // for pals, we do not want to add the generation prompt to the formatted prompt
...@@ -488,3 +595,51 @@ impl ...@@ -488,3 +595,51 @@ impl
Ok(ResponseStream::new(Box::pin(stream), context)) Ok(ResponseStream::new(Box::pin(stream), context))
} }
} }
#[async_trait]
impl
Operator<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
> for OpenAIPreprocessor
{
async fn generate(
&self,
request: SingleIn<NvCreateEmbeddingRequest>,
next: Arc<
dyn AsyncEngine<
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
Error,
>,
>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
// Unpack request
let (request, context) = request.into_parts();
// Preprocess the embedding request
let (preprocessed_request, annotations) =
self.preprocess_embedding_request(&request).await?;
// Forward to next stage
let preprocessed_request = context.map(|_| preprocessed_request);
let response_stream = next.generate(preprocessed_request).await?;
// Transform response stream back to OpenAI format
let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);
let context = stream.context();
// Prepend annotations
let annotations_stream = stream::iter(
annotations
.into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect::<Vec<_>>(),
);
let combined_stream = annotations_stream.chain(stream);
Ok(ResponseStream::new(Box::pin(combined_stream), context))
}
}
...@@ -133,3 +133,14 @@ impl LLMEngineOutput { ...@@ -133,3 +133,14 @@ impl LLMEngineOutput {
} }
} }
} }
/// Raw output from embedding engines containing embedding vectors
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput {
/// Generated embedding vectors (one per input text)
pub embeddings: Vec<Vec<f64>>,
/// Token usage information
pub prompt_tokens: u32,
pub total_tokens: u32,
}
...@@ -68,3 +68,40 @@ impl PreprocessedRequest { ...@@ -68,3 +68,40 @@ impl PreprocessedRequest {
PreprocessedRequestBuilder::default() PreprocessedRequestBuilder::default()
} }
} }
/// [`PreprocessedEmbeddingRequest`] is the internal representation of an embedding request
/// after preprocessing. Contains tokenized input ready for embedding engines.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct PreprocessedEmbeddingRequest {
/// Tokenized input text as token IDs (one Vec per input text)
pub token_ids: Vec<Vec<TokenIdType>>,
/// Model to use for embedding
pub model: String,
/// Encoding format preference
pub encoding_format: Option<String>,
/// Number of dimensions for output embeddings (if supported)
pub dimensions: Option<u32>,
/// The computed checksum of the Model Deployment Card (MDC)
#[builder(default)]
pub mdc_sum: Option<String>,
/// User requested annotations for the request
#[builder(default)]
pub annotations: Vec<String>,
}
impl PreprocessedEmbeddingRequest {
pub fn has_annotation(&self, annotation: &str) -> bool {
self.annotations.contains(&annotation.to_string())
}
}
impl PreprocessedEmbeddingRequest {
pub fn builder() -> PreprocessedEmbeddingRequestBuilder {
PreprocessedEmbeddingRequestBuilder::default()
}
}
...@@ -58,6 +58,7 @@ pub mod traits { ...@@ -58,6 +58,7 @@ pub mod traits {
pub trait Encoder: Send + Sync { pub trait Encoder: Send + Sync {
fn encode(&self, input: &str) -> Result<Encoding>; fn encode(&self, input: &str) -> Result<Encoding>;
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
} }
pub trait Decoder: Send + Sync { pub trait Decoder: Send + Sync {
......
...@@ -54,6 +54,30 @@ impl Encoder for HuggingFaceTokenizer { ...@@ -54,6 +54,30 @@ impl Encoder for HuggingFaceTokenizer {
spans, spans,
}) })
} }
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let hf_encodings = self
.tokenizer
.encode_batch(inputs.to_vec(), false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
let encodings = hf_encodings
.into_iter()
.map(|encoding| {
let token_ids = encoding.get_ids().to_vec();
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Encoding {
token_ids,
tokens,
spans,
}
})
.collect();
Ok(encodings)
}
} }
impl Decoder for HuggingFaceTokenizer { impl Decoder for HuggingFaceTokenizer {
......
...@@ -73,6 +73,17 @@ impl Encoder for SentencePieceTokenizer { ...@@ -73,6 +73,17 @@ impl Encoder for SentencePieceTokenizer {
spans, spans,
}) })
} }
/// Encodes multiple string inputs into tokens using the SentencePiece model.
///
/// # Arguments
/// * `inputs` - The texts to encode
///
/// # Returns
/// * `Result<Vec<Encoding>>` - The encoded tokens for each input
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
inputs.iter().map(|input| self.encode(input)).collect()
}
} }
impl Decoder for SentencePieceTokenizer { impl Decoder for SentencePieceTokenizer {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment