Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
e83009a6
Unverified
Commit
e83009a6
authored
Jun 04, 2025
by
Tom O'Brien
Committed by
GitHub
Jun 04, 2025
Browse files
feat: add implementation for embeddings (#1290)
parent
5e9370d3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
350 additions
and
13 deletions
+350
-13
launch/dynamo-run/src/input/http.rs
launch/dynamo-run/src/input/http.rs
+1
-0
launch/dynamo-run/src/subprocess/sglang_inc.py
launch/dynamo-run/src/subprocess/sglang_inc.py
+46
-4
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+1
-3
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+53
-4
lib/llm/src/http/service/service_v2.rs
lib/llm/src/http/service/service_v2.rs
+1
-1
lib/llm/src/protocols/openai/embeddings.rs
lib/llm/src/protocols/openai/embeddings.rs
+4
-1
lib/llm/src/protocols/openai/embeddings/aggregator.rs
lib/llm/src/protocols/openai/embeddings/aggregator.rs
+244
-0
No files found.
launch/dynamo-run/src/input/http.rs
View file @
e83009a6
...
...
@@ -33,6 +33,7 @@ pub async fn run(
.port
(
flags
.http_port
)
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.enable_embeddings_endpoints
(
true
)
.with_request_template
(
template
)
.build
()
?
;
match
engine_config
{
...
...
launch/dynamo-run/src/subprocess/sglang_inc.py
View file @
e83009a6
...
...
@@ -77,6 +77,42 @@ class RequestHandler:
num_output_tokens_so_far
=
next_total_toks
class EmbeddingRequestHandler(RequestHandler):
    """
    Handler that serves embedding requests and answers with a single
    list-style payload (`object`, `model`, `data`, `usage`).
    """

    def __init__(self, engine: sglang.Engine, model_name: str):
        super().__init__(engine)
        # Echoed back in the "model" field of every response.
        self._model_name = model_name

    async def generate(self, request):
        """
        Encode `request["input"]` with the engine and yield exactly one
        response dict containing every embedding plus token usage.
        """
        # NOTE(review): assumes async_encode returns an iterable of per-input
        # results, each with "embedding" and "meta_info" keys - confirm this
        # also holds when "input" is a single string.
        results = await self.engine_client.async_encode(prompt=request["input"])

        data = []
        prompt_token_count = 0
        for position, result in enumerate(results):
            item = {
                "index": position,
                "object": "embedding",
                "embedding": result["embedding"],
            }
            data.append(item)
            prompt_token_count += result["meta_info"]["prompt_tokens"]

        # No completion tokens are counted here, so total equals prompt.
        usage = {
            "prompt_tokens": prompt_token_count,
            "total_tokens": prompt_token_count,
        }
        yield {
            "object": "list",
            "model": self._model_name,
            "data": data,
            "usage": usage,
        }
@
dynamo_worker
(
static
=
False
)
async
def
worker
(
runtime
:
DistributedRuntime
):
await
init
(
runtime
,
cmd_line_args
())
...
...
@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config):
await
component
.
create_service
()
endpoint
=
component
.
endpoint
(
config
.
endpoint
)
await
register_llm
(
ModelType
.
Backend
,
endpoint
,
config
.
model_path
,
config
.
model_name
model_type
=
(
ModelType
.
Backend
if
not
engine_args
.
is_embedding
else
ModelType
.
Embedding
)
await
register_llm
(
model_type
,
endpoint
,
config
.
model_path
,
config
.
model_name
)
# the server will shut down gracefully (i.e., let already-open TCP streams finish)
# after the lease is revoked
await
endpoint
.
serve_endpoint
(
RequestHandler
(
engine_client
).
generate
)
await
endpoint
.
serve_endpoint
(
RequestHandler
(
engine_client
).
generate
if
not
engine_args
.
is_embedding
else
EmbeddingRequestHandler
(
engine_client
,
model_name
=
config
.
model_name
or
config
.
model_path
).
generate
)
def
cmd_line_args
():
...
...
@@ -230,7 +273,6 @@ def cmd_line_args():
config
.
node_rank
=
args
.
node_rank
config
.
dist_init_addr
=
args
.
dist_init_addr
config
.
extra_engine_args
=
args
.
extra_engine_args
return
config
...
...
lib/llm/src/discovery/model_manager.rs
View file @
e83009a6
...
...
@@ -129,9 +129,7 @@ impl ModelManager {
clients
.remove
(
model
)
}
// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
#[allow(dead_code)]
fn
get_embeddings_engine
(
pub
fn
get_embeddings_engine
(
&
self
,
model
:
&
str
,
)
->
Result
<
OpenAIEmbeddingsStreamingEngine
,
ModelManagerError
>
{
...
...
lib/llm/src/http/service/openai.rs
View file @
e83009a6
...
...
@@ -27,7 +27,7 @@ use super::{
service_v2
,
RouteDoc
,
};
use
crate
::
protocols
::
openai
::
embeddings
::
NvCreateEmbeddingRequest
;
use
crate
::
protocols
::
openai
::
embeddings
::
{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
}
;
use
crate
::
protocols
::
openai
::{
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
CompletionResponse
,
};
...
...
@@ -208,10 +208,59 @@ async fn completions(
#[tracing::instrument(skip_all)]
async
fn
embeddings
(
State
(
_
state
):
State
<
Arc
<
service_v2
::
State
>>
,
Json
(
_
request
):
Json
<
NvCreateEmbeddingRequest
>
,
State
(
state
):
State
<
Arc
<
service_v2
::
State
>>
,
Json
(
request
):
Json
<
NvCreateEmbeddingRequest
>
,
)
->
Result
<
Response
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
unimplemented!
(
"embeddings are not supported yet"
);
// return a 503 if the service is not ready
check_ready
(
&
state
)
?
;
// todo - extract distributed tracing id and context id from headers
let
request_id
=
uuid
::
Uuid
::
new_v4
()
.to_string
();
// Embeddings are typically not streamed, so we default to non-streaming
let
streaming
=
false
;
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let
model
=
&
request
.inner.model
;
// todo - error handling should be more robust
let
engine
=
state
.manager
()
.get_embeddings_engine
(
model
)
.map_err
(|
_
|
ErrorResponse
::
model_not_found
())
?
;
// this will increment the inflight gauge for the model
let
mut
inflight
=
state
.metrics_clone
()
.create_inflight_guard
(
model
,
Endpoint
::
Embeddings
,
streaming
);
// setup context
// todo - inherit request_id from distributed trace details
let
request
=
Context
::
with_id
(
request
,
request_id
.clone
());
// issue the generate call on the engine
let
stream
=
engine
.generate
(
request
)
.await
.map_err
(|
e
|
ErrorResponse
::
from_anyhow
(
e
,
"Failed to generate embeddings"
))
?
;
// Embeddings are typically returned as a single response (non-streaming)
// so we fold the stream into a single response
let
response
=
NvCreateEmbeddingResponse
::
from_annotated_stream
(
stream
.into
())
.await
.map_err
(|
e
|
{
tracing
::
error!
(
"Failed to fold embeddings stream for {}: {:?}"
,
request_id
,
e
);
ErrorResponse
::
internal_server_error
(
"Failed to fold embeddings stream"
)
})
?
;
inflight
.mark_ok
();
Ok
(
Json
(
response
)
.into_response
())
}
/// OpenAI Chat Completions Request Handler
...
...
lib/llm/src/http/service/service_v2.rs
View file @
e83009a6
...
...
@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
#[builder(default
=
"true"
)]
enable_cmpl_endpoints
:
bool
,
#[builder(default
=
"
fals
e"
)]
#[builder(default
=
"
tru
e"
)]
enable_embeddings_endpoints
:
bool
,
#[builder(default
=
"None"
)]
...
...
lib/llm/src/protocols/openai/embeddings.rs
View file @
e83009a6
...
...
@@ -16,9 +16,12 @@
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
mod
aggregator
;
mod
nvext
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
// pub use delta::DeltaGenerator;
pub
use
aggregator
::
DeltaAggregator
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
...
...
@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse {
}
}
/// Implements `NvExtProvider` for `NvCr
eateEmbeddingRequest`,
/// Implements `NvExtProvider` for `NvCreateEmbeddingRequest`,
/// providing access to NVIDIA-specific extensions.
impl
NvExtProvider
for
NvCreateEmbeddingRequest
{
/// Returns a reference to the optional `NvExt` extension, if available.
...
...
lib/llm/src/protocols/openai/embeddings/aggregator.rs
0 → 100644
View file @
e83009a6
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
NvCreateEmbeddingResponse
;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::
pin
::
Pin
;
/// Boxed, pinned, type-erased [`Stream`] yielding items of type `T`.
///
/// The trait object is additionally bounded by `Send + Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
/// Accumulator used to collapse a stream of [`NvCreateEmbeddingResponse`] chunks
/// into one [`NvCreateEmbeddingResponse`].
///
/// It tracks the response merged so far and the error message, if any chunk
/// in the stream carried an error.
pub struct DeltaAggregator {
    /// Response merged so far; `None` until a chunk carrying data has been seen.
    response: Option<NvCreateEmbeddingResponse>,
    /// Error message captured while aggregating, if one occurred.
    error: Option<String>,
}
impl
Default
for
DeltaAggregator
{
/// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
fn
default
()
->
Self
{
Self
::
new
()
}
}
impl DeltaAggregator {
    /// Creates a new, empty [`DeltaAggregator`] with no response and no error recorded.
    pub fn new() -> Self {
        Self {
            response: None,
            error: None,
        }
    }

    /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
    /// [`NvCreateEmbeddingResponse`].
    ///
    /// The first chunk carrying data becomes the base response. Every later
    /// chunk has its embedding entries appended to the base and its
    /// `prompt_tokens` / `total_tokens` added to the base usage counters.
    /// Chunks without data are ignored.
    ///
    /// # Arguments
    /// * `stream` - A stream of annotated embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` with the merged response, or
    ///   `NvCreateEmbeddingResponse::empty()` if no chunk carried data.
    /// * `Err(String)` with the message of the **first** error chunk seen.
    ///   Data arriving after an error is discarded.
    pub async fn apply(
        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let aggregator = stream
            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                // Attempt to unwrap the delta, capturing any errors.
                let delta = match delta.ok() {
                    Ok(delta) => delta,
                    Err(error) => {
                        // Keep the first error only: it is usually the root
                        // cause, and later errors must not overwrite it.
                        if aggregator.error.is_none() {
                            aggregator.error = Some(error);
                        }
                        return aggregator;
                    }
                };

                // Once an error has been recorded the result is already `Err`,
                // so merging further payloads would be wasted work.
                if aggregator.error.is_none() {
                    if let Some(response) = delta.data {
                        match &mut aggregator.response {
                            Some(existing) => {
                                // Merge embedding data from additional responses.
                                existing.inner.data.extend(response.inner.data);

                                // Accumulate usage statistics.
                                existing.inner.usage.prompt_tokens +=
                                    response.inner.usage.prompt_tokens;
                                existing.inner.usage.total_tokens +=
                                    response.inner.usage.total_tokens;
                            }
                            None => {
                                // First payload: use it as the base response.
                                aggregator.response = Some(response);
                            }
                        }
                    }
                }

                aggregator
            })
            .await;

        // Surface the recorded error, if any.
        if let Some(error) = aggregator.error {
            return Err(error);
        }

        // Return the aggregated response or an empty response if none was found.
        Ok(aggregator
            .response
            .unwrap_or_else(NvCreateEmbeddingResponse::empty))
    }
}
impl NvCreateEmbeddingResponse {
    /// Builds a single [`NvCreateEmbeddingResponse`] from a stream of SSE messages.
    ///
    /// The SSE messages are first converted with `convert_sse_stream` and the
    /// resulting annotated stream is handed to [`Self::from_annotated_stream`].
    ///
    /// # Arguments
    /// * `stream` - A stream of SSE messages containing embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
    /// * `Err(String)` if an error occurs.
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        Self::from_annotated_stream(convert_sse_stream::<Self>(stream)).await
    }

    /// Collapses an annotated stream of embedding responses into one final
    /// response by delegating to [`DeltaAggregator::apply`].
    ///
    /// # Arguments
    /// * `stream` - A stream of annotated embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
    /// * `Err(String)` if an error occurs.
    pub async fn from_annotated_stream(
        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let aggregated = DeltaAggregator::apply(stream).await;
        aggregated
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Wraps `embeddings` and the given usage counters in an annotated
    /// response for the fixed model name `"test-model"`.
    fn create_test_embedding_response(
        embeddings: Vec<async_openai::types::Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        Annotated::from_data(NvCreateEmbeddingResponse {
            inner: async_openai::types::CreateEmbeddingResponse {
                object: "list".to_string(),
                model: "test-model".to_string(),
                data: embeddings,
                usage: async_openai::types::EmbeddingUsage {
                    prompt_tokens,
                    total_tokens,
                },
            },
        })
    }

    /// An empty stream aggregates to the empty response.
    #[tokio::test]
    async fn test_empty_stream() {
        let no_chunks = stream::empty();
        let outcome = DeltaAggregator::apply(Box::pin(no_chunks)).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 0);
        assert_eq!(aggregated.inner.object, "list");
        assert_eq!(aggregated.inner.model, "embedding");
    }

    /// A single chunk passes through unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let only = async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let chunk = create_test_embedding_response(vec![only], 10, 10);

        let chunks = stream::iter(vec![chunk]);
        let outcome = DeltaAggregator::apply(Box::pin(chunks)).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 1);
        assert_eq!(aggregated.inner.data[0].index, 0);
        assert_eq!(aggregated.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(aggregated.inner.usage.prompt_tokens, 10);
        assert_eq!(aggregated.inner.usage.total_tokens, 10);
    }

    /// Two chunks are merged: data is concatenated and usage is summed.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let second = async_openai::types::Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };

        let first_chunk = create_test_embedding_response(vec![first], 5, 5);
        let second_chunk = create_test_embedding_response(vec![second], 7, 7);

        let chunks = stream::iter(vec![first_chunk, second_chunk]);
        let outcome = DeltaAggregator::apply(Box::pin(chunks)).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 2);
        assert_eq!(aggregated.inner.data[0].index, 0);
        assert_eq!(aggregated.inner.data[1].index, 1);
        assert_eq!(aggregated.inner.usage.prompt_tokens, 12); // sum of 5 and 7
        assert_eq!(aggregated.inner.usage.total_tokens, 12); // sum of 5 and 7
    }

    /// An error chunk makes the whole aggregation fail with that message.
    #[tokio::test]
    async fn test_error_in_stream() {
        let failing_chunk =
            Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());

        let chunks = stream::iter(vec![failing_chunk]);
        let outcome = DeltaAggregator::apply(Box::pin(chunks)).await;

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Test error"));
    }
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment