Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
e83009a6
"vscode:/vscode.git/clone" did not exist on "8d9ccdfb8da596aaf89f7314d75c1752f0405703"
Unverified
Commit
e83009a6
authored
Jun 04, 2025
by
Tom O'Brien
Committed by
GitHub
Jun 04, 2025
Browse files
feat: add implementation for embeddings (#1290)
parent
5e9370d3
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
350 additions
and
13 deletions
+350
-13
launch/dynamo-run/src/input/http.rs
launch/dynamo-run/src/input/http.rs
+1
-0
launch/dynamo-run/src/subprocess/sglang_inc.py
launch/dynamo-run/src/subprocess/sglang_inc.py
+46
-4
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+1
-3
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+53
-4
lib/llm/src/http/service/service_v2.rs
lib/llm/src/http/service/service_v2.rs
+1
-1
lib/llm/src/protocols/openai/embeddings.rs
lib/llm/src/protocols/openai/embeddings.rs
+4
-1
lib/llm/src/protocols/openai/embeddings/aggregator.rs
lib/llm/src/protocols/openai/embeddings/aggregator.rs
+244
-0
No files found.
launch/dynamo-run/src/input/http.rs
View file @
e83009a6
...
@@ -33,6 +33,7 @@ pub async fn run(
...
@@ -33,6 +33,7 @@ pub async fn run(
.port
(
flags
.http_port
)
.port
(
flags
.http_port
)
.enable_chat_endpoints
(
true
)
.enable_chat_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.enable_cmpl_endpoints
(
true
)
.enable_embeddings_endpoints
(
true
)
.with_request_template
(
template
)
.with_request_template
(
template
)
.build
()
?
;
.build
()
?
;
match
engine_config
{
match
engine_config
{
...
...
launch/dynamo-run/src/subprocess/sglang_inc.py
View file @
e83009a6
...
@@ -77,6 +77,42 @@ class RequestHandler:
...
@@ -77,6 +77,42 @@ class RequestHandler:
num_output_tokens_so_far
=
next_total_toks
num_output_tokens_so_far
=
next_total_toks
class EmbeddingRequestHandler(RequestHandler):
    """
    Request handler for the embedding endpoint.

    Encodes the request's "input" with the engine and yields exactly one
    payload: an "object": "list" response holding one "embedding" entry per
    encoded input plus the summed prompt-token usage.
    """

    def __init__(self, engine: sglang.Engine, model_name: str):
        super().__init__(engine)
        # Echoed back verbatim in the "model" field of every response.
        self._model_name = model_name

    async def generate(self, request):
        """
        Encode request["input"] and yield a single embeddings response dict.

        Each entry's "index" is its position in the engine's result order.
        "total_tokens" equals "prompt_tokens" because only the prompt token
        counts reported by the engine are accumulated here.
        """
        results = await self.engine_client.async_encode(prompt=request["input"])

        # NOTE(review): for a single (non-list) prompt the engine appears to
        # return one result dict instead of a list of them - confirm against
        # the SGLang API. Iterating a dict would walk its keys and break the
        # res["embedding"] lookups below, so normalise to a list first. This is
        # a no-op whenever a list is returned.
        if isinstance(results, dict):
            results = [results]

        tokens = 0
        embeddings = []
        for idx, res in enumerate(results):
            embeddings.append(
                {
                    "index": idx,
                    "object": "embedding",
                    "embedding": res["embedding"],
                }
            )
            tokens += res["meta_info"]["prompt_tokens"]

        out = {
            "object": "list",
            "model": self._model_name,
            "data": embeddings,
            "usage": {
                "prompt_tokens": tokens,
                "total_tokens": tokens,
            },
        }
        yield out
@
dynamo_worker
(
static
=
False
)
@
dynamo_worker
(
static
=
False
)
async
def
worker
(
runtime
:
DistributedRuntime
):
async
def
worker
(
runtime
:
DistributedRuntime
):
await
init
(
runtime
,
cmd_line_args
())
await
init
(
runtime
,
cmd_line_args
())
...
@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config):
await
component
.
create_service
()
await
component
.
create_service
()
endpoint
=
component
.
endpoint
(
config
.
endpoint
)
endpoint
=
component
.
endpoint
(
config
.
endpoint
)
await
register_llm
(
model_type
=
(
ModelType
.
Backend
,
endpoint
,
config
.
model_path
,
config
.
model_name
ModelType
.
Backend
if
not
engine_args
.
is_embedding
else
ModelType
.
Embedding
)
)
await
register_llm
(
model_type
,
endpoint
,
config
.
model_path
,
config
.
model_name
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
# after the lease is revoked
await
endpoint
.
serve_endpoint
(
RequestHandler
(
engine_client
).
generate
)
await
endpoint
.
serve_endpoint
(
RequestHandler
(
engine_client
).
generate
if
not
engine_args
.
is_embedding
else
EmbeddingRequestHandler
(
engine_client
,
model_name
=
config
.
model_name
or
config
.
model_path
).
generate
)
def
cmd_line_args
():
def
cmd_line_args
():
...
@@ -230,7 +273,6 @@ def cmd_line_args():
...
@@ -230,7 +273,6 @@ def cmd_line_args():
config
.
node_rank
=
args
.
node_rank
config
.
node_rank
=
args
.
node_rank
config
.
dist_init_addr
=
args
.
dist_init_addr
config
.
dist_init_addr
=
args
.
dist_init_addr
config
.
extra_engine_args
=
args
.
extra_engine_args
config
.
extra_engine_args
=
args
.
extra_engine_args
return
config
return
config
...
...
lib/llm/src/discovery/model_manager.rs
View file @
e83009a6
...
@@ -129,9 +129,7 @@ impl ModelManager {
...
@@ -129,9 +129,7 @@ impl ModelManager {
clients
.remove
(
model
)
clients
.remove
(
model
)
}
}
// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
pub
fn
get_embeddings_engine
(
#[allow(dead_code)]
fn
get_embeddings_engine
(
&
self
,
&
self
,
model
:
&
str
,
model
:
&
str
,
)
->
Result
<
OpenAIEmbeddingsStreamingEngine
,
ModelManagerError
>
{
)
->
Result
<
OpenAIEmbeddingsStreamingEngine
,
ModelManagerError
>
{
...
...
lib/llm/src/http/service/openai.rs
View file @
e83009a6
...
@@ -27,7 +27,7 @@ use super::{
...
@@ -27,7 +27,7 @@ use super::{
service_v2
,
RouteDoc
,
service_v2
,
RouteDoc
,
};
};
use
crate
::
protocols
::
openai
::
embeddings
::
NvCreateEmbeddingRequest
;
use
crate
::
protocols
::
openai
::
embeddings
::
{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
}
;
use
crate
::
protocols
::
openai
::{
use
crate
::
protocols
::
openai
::{
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
CompletionResponse
,
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
CompletionResponse
,
};
};
...
@@ -208,10 +208,59 @@ async fn completions(
...
@@ -208,10 +208,59 @@ async fn completions(
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all)]
async
fn
embeddings
(
async
fn
embeddings
(
State
(
_
state
):
State
<
Arc
<
service_v2
::
State
>>
,
State
(
state
):
State
<
Arc
<
service_v2
::
State
>>
,
Json
(
_
request
):
Json
<
NvCreateEmbeddingRequest
>
,
Json
(
request
):
Json
<
NvCreateEmbeddingRequest
>
,
)
->
Result
<
Response
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
)
->
Result
<
Response
,
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
unimplemented!
(
"embeddings are not supported yet"
);
// return a 503 if the service is not ready
check_ready
(
&
state
)
?
;
// todo - extract distributed tracing id and context id from headers
let
request_id
=
uuid
::
Uuid
::
new_v4
()
.to_string
();
// Embeddings are typically not streamed, so we default to non-streaming
let
streaming
=
false
;
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let
model
=
&
request
.inner.model
;
// todo - error handling should be more robust
let
engine
=
state
.manager
()
.get_embeddings_engine
(
model
)
.map_err
(|
_
|
ErrorResponse
::
model_not_found
())
?
;
// this will increment the inflight gauge for the model
let
mut
inflight
=
state
.metrics_clone
()
.create_inflight_guard
(
model
,
Endpoint
::
Embeddings
,
streaming
);
// setup context
// todo - inherit request_id from distributed trace details
let
request
=
Context
::
with_id
(
request
,
request_id
.clone
());
// issue the generate call on the engine
let
stream
=
engine
.generate
(
request
)
.await
.map_err
(|
e
|
ErrorResponse
::
from_anyhow
(
e
,
"Failed to generate embeddings"
))
?
;
// Embeddings are typically returned as a single response (non-streaming)
// so we fold the stream into a single response
let
response
=
NvCreateEmbeddingResponse
::
from_annotated_stream
(
stream
.into
())
.await
.map_err
(|
e
|
{
tracing
::
error!
(
"Failed to fold embeddings stream for {}: {:?}"
,
request_id
,
e
);
ErrorResponse
::
internal_server_error
(
"Failed to fold embeddings stream"
)
})
?
;
inflight
.mark_ok
();
Ok
(
Json
(
response
)
.into_response
())
}
}
/// OpenAI Chat Completions Request Handler
/// OpenAI Chat Completions Request Handler
...
...
lib/llm/src/http/service/service_v2.rs
View file @
e83009a6
...
@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
...
@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
#[builder(default
=
"true"
)]
#[builder(default
=
"true"
)]
enable_cmpl_endpoints
:
bool
,
enable_cmpl_endpoints
:
bool
,
#[builder(default
=
"
fals
e"
)]
#[builder(default
=
"
tru
e"
)]
enable_embeddings_endpoints
:
bool
,
enable_embeddings_endpoints
:
bool
,
#[builder(default
=
"None"
)]
#[builder(default
=
"None"
)]
...
...
lib/llm/src/protocols/openai/embeddings.rs
View file @
e83009a6
...
@@ -16,9 +16,12 @@
...
@@ -16,9 +16,12 @@
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
use
validator
::
Validate
;
mod
aggregator
;
mod
nvext
;
mod
nvext
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
// pub use delta::DeltaGenerator;
pub
use
aggregator
::
DeltaAggregator
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
...
@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse {
...
@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse {
}
}
}
}
/// Implements `NvExtProvider` for `NvCr
eateEmbeddingRequest`,
/// Implements `NvExtProvider` for `NvCreateEmbeddingRequest`,
/// providing access to NVIDIA-specific extensions.
/// providing access to NVIDIA-specific extensions.
impl
NvExtProvider
for
NvCreateEmbeddingRequest
{
impl
NvExtProvider
for
NvCreateEmbeddingRequest
{
/// Returns a reference to the optional `NvExt` extension, if available.
/// Returns a reference to the optional `NvExt` extension, if available.
...
...
lib/llm/src/protocols/openai/embeddings/aggregator.rs
0 → 100644
View file @
e83009a6
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
NvCreateEmbeddingResponse
;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::
pin
::
Pin
;
/// A pinned, heap-allocated, dynamically-dispatched [`Stream`] yielding items of type `T`.
///
/// `Send + Sync` are part of the trait object, so any concrete stream boxed into this
/// alias has to satisfy both bounds.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
/// Accumulator used to collapse a stream of [`NvCreateEmbeddingResponse`]s into a
/// single [`NvCreateEmbeddingResponse`].
///
/// It carries the state threaded through the fold performed by
/// [`DeltaAggregator::apply`]: the response merged so far and any error seen.
pub struct DeltaAggregator {
    /// Response built up so far; `None` until the first item carrying data arrives.
    response: Option<NvCreateEmbeddingResponse>,
    /// Error message recorded while aggregating; `None` while everything succeeded.
    error: Option<String>,
}
impl
Default
for
DeltaAggregator
{
/// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
fn
default
()
->
Self
{
Self
::
new
()
}
}
impl DeltaAggregator {
    /// Creates a new, empty [`DeltaAggregator`]: no response accumulated, no error recorded.
    pub fn new() -> Self {
        Self {
            response: None,
            error: None,
        }
    }

    /// Aggregates a stream of annotated [`NvCreateEmbeddingResponse`]s into a single
    /// [`NvCreateEmbeddingResponse`].
    ///
    /// The first item that carries data becomes the base response. For every later item
    /// with data, its embeddings are appended to the base response and its usage counters
    /// are added to the base counters (saturating instead of overflowing). Items without
    /// data are ignored.
    ///
    /// The stream is always consumed to the end, but once an error has been seen no
    /// further data is merged.
    ///
    /// # Arguments
    /// * `stream` - A stream of annotated embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` with the merged response, or the value of
    ///   `NvCreateEmbeddingResponse::empty` if no item carried data.
    ///
    /// # Errors
    /// * `Err(String)` holding the *first* error reported by an item of the stream.
    pub async fn apply(
        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let aggregator = stream
            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                // Attempt to unwrap the delta, capturing any errors.
                let delta = match delta.ok() {
                    Ok(delta) => delta,
                    Err(error) => {
                        // Keep the first error only: it is the root cause, and later
                        // errors must not overwrite it.
                        if aggregator.error.is_none() {
                            aggregator.error = Some(error);
                        }
                        return aggregator;
                    }
                };

                // The result is already an `Err` once an error was recorded, so merging
                // further data would be wasted work.
                if aggregator.error.is_some() {
                    return aggregator;
                }

                if let Some(response) = delta.data {
                    match &mut aggregator.response {
                        Some(existing) => {
                            // Merge embedding data when several responses arrive.
                            existing.inner.data.extend(response.inner.data);

                            // Accumulate usage statistics. Saturate rather than panic
                            // (debug) or wrap (release) if the counters would overflow.
                            let usage = &mut existing.inner.usage;
                            usage.prompt_tokens = usage
                                .prompt_tokens
                                .saturating_add(response.inner.usage.prompt_tokens);
                            usage.total_tokens = usage
                                .total_tokens
                                .saturating_add(response.inner.usage.total_tokens);
                        }
                        None => {
                            // First response with data becomes the base of the result.
                            aggregator.response = Some(response);
                        }
                    }
                }

                aggregator
            })
            .await;

        // Report the recorded error, if any, instead of a partial response.
        if let Some(error) = aggregator.error {
            return Err(error);
        }

        // Return the aggregated response or an empty response if none was found.
        Ok(aggregator
            .response
            .unwrap_or_else(NvCreateEmbeddingResponse::empty))
    }
}
impl NvCreateEmbeddingResponse {
    /// Builds a single [`NvCreateEmbeddingResponse`] from a stream of SSE messages.
    ///
    /// The messages are first converted with `convert_sse_stream` into annotated
    /// embedding responses, which are then handed to
    /// [`NvCreateEmbeddingResponse::from_annotated_stream`].
    ///
    /// # Arguments
    /// * `stream` - A stream of SSE messages (or codec errors) carrying embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` when aggregation succeeds.
    /// * `Err(String)` when aggregation reports an error.
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let annotated = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
        Self::from_annotated_stream(annotated).await
    }

    /// Collapses an annotated stream of embedding responses into the final response.
    ///
    /// This is a thin entry point over [`DeltaAggregator::apply`].
    ///
    /// # Arguments
    /// * `stream` - A stream of annotated embedding responses.
    ///
    /// # Returns
    /// * `Ok(NvCreateEmbeddingResponse)` when aggregation succeeds.
    /// * `Err(String)` when aggregation reports an error.
    pub async fn from_annotated_stream(
        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let aggregated = DeltaAggregator::apply(stream).await;
        aggregated
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Wraps `embeddings` and the given usage counters in an annotated response
    /// for the model `"test-model"`.
    fn create_test_embedding_response(
        embeddings: Vec<async_openai::types::Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        let usage = async_openai::types::EmbeddingUsage {
            prompt_tokens,
            total_tokens,
        };
        let inner = async_openai::types::CreateEmbeddingResponse {
            object: "list".to_string(),
            model: "test-model".to_string(),
            data: embeddings,
            usage,
        };
        Annotated::from_data(NvCreateEmbeddingResponse { inner })
    }

    /// Feeds `items`, in order, through [`DeltaAggregator::apply`].
    async fn aggregate(
        items: Vec<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        DeltaAggregator::apply(Box::pin(stream::iter(items))).await
    }

    /// An empty stream yields the empty response rather than an error.
    #[tokio::test]
    async fn test_empty_stream() {
        let no_items = stream::empty::<Annotated<NvCreateEmbeddingResponse>>();
        let response = DeltaAggregator::apply(Box::pin(no_items))
            .await
            .expect("an empty stream should aggregate successfully");

        assert_eq!(response.inner.data.len(), 0);
        assert_eq!(response.inner.object, "list");
        assert_eq!(response.inner.model, "embedding");
    }

    /// A single response passes through unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let only = async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };

        let response = aggregate(vec![create_test_embedding_response(vec![only], 10, 10)])
            .await
            .expect("a single response should aggregate successfully");

        assert_eq!(response.inner.data.len(), 1);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(response.inner.usage.prompt_tokens, 10);
        assert_eq!(response.inner.usage.total_tokens, 10);
    }

    /// Several responses are concatenated in order and their usage is summed.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let second = async_openai::types::Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };

        let items = vec![
            create_test_embedding_response(vec![first], 5, 5),
            create_test_embedding_response(vec![second], 7, 7),
        ];
        let response = aggregate(items)
            .await
            .expect("multiple responses should aggregate successfully");

        assert_eq!(response.inner.data.len(), 2);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[1].index, 1);
        // Usage counters are the sums of 5 and 7.
        assert_eq!(response.inner.usage.prompt_tokens, 12);
        assert_eq!(response.inner.usage.total_tokens, 12);
    }

    /// An error item in the stream turns the whole aggregation into an error.
    #[tokio::test]
    async fn test_error_in_stream() {
        let failing =
            Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());

        let message = aggregate(vec![failing])
            .await
            .expect_err("an error item should fail the aggregation");

        assert!(message.contains("Test error"));
    }
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment