Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
e3f1bd5d
Unverified
Commit
e3f1bd5d
authored
Jun 26, 2025
by
Paul Hendricks
Committed by
GitHub
Jun 26, 2025
Browse files
refactor: refactored using CompletionResponse (#1658)
parent
7b7b6a6d
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
180 additions
and
159 deletions
+180
-159
launch/dynamo-run/src/input/common.rs
launch/dynamo-run/src/input/common.rs
+3
-2
launch/dynamo-run/src/input/http.rs
launch/dynamo-run/src/input/http.rs
+2
-2
lib/engines/mistralrs/src/lib.rs
lib/engines/mistralrs/src/lib.rs
+8
-4
lib/llm/src/discovery/watcher.rs
lib/llm/src/discovery/watcher.rs
+3
-3
lib/llm/src/engines.rs
lib/llm/src/engines.rs
+18
-10
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+2
-2
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+4
-4
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+9
-39
lib/llm/src/protocols/openai/completions/aggregator.rs
lib/llm/src/protocols/openai/completions/aggregator.rs
+104
-76
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+9
-7
lib/llm/src/types.rs
lib/llm/src/types.rs
+5
-3
lib/llm/tests/aggregators.rs
lib/llm/tests/aggregators.rs
+5
-3
lib/llm/tests/http-service.rs
lib/llm/tests/http-service.rs
+8
-4
No files found.
launch/dynamo-run/src/input/common.rs
View file @
e3f1bd5d
...
@@ -139,7 +139,7 @@ mod tests {
...
@@ -139,7 +139,7 @@ mod tests {
use
super
::
*
;
use
super
::
*
;
use
dynamo_llm
::
types
::
openai
::{
use
dynamo_llm
::
types
::
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
completions
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
};
};
const
HF_PATH
:
&
str
=
concat!
(
const
HF_PATH
:
&
str
=
concat!
(
...
@@ -174,7 +174,8 @@ mod tests {
...
@@ -174,7 +174,8 @@ mod tests {
// Build pipeline for completions
// Build pipeline for completions
let
pipeline
=
let
pipeline
=
build_pipeline
::
<
NvCreateCompletionRequest
,
CompletionResponse
>
(
&
card
,
engine
)
.await
?
;
build_pipeline
::
<
NvCreateCompletionRequest
,
NvCreateCompletionResponse
>
(
&
card
,
engine
)
.await
?
;
// Verify pipeline was created
// Verify pipeline was created
assert
!
(
Arc
::
strong_count
(
&
pipeline
)
>=
1
);
assert
!
(
Arc
::
strong_count
(
&
pipeline
)
>=
1
);
...
...
launch/dynamo-run/src/input/http.rs
View file @
e3f1bd5d
...
@@ -15,7 +15,7 @@ use dynamo_llm::{
...
@@ -15,7 +15,7 @@ use dynamo_llm::{
openai
::
chat_completions
::{
openai
::
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
,
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
,
},
},
openai
::
completions
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
openai
::
completions
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
},
},
};
};
use
dynamo_runtime
::
pipeline
::
RouterMode
;
use
dynamo_runtime
::
pipeline
::
RouterMode
;
...
@@ -78,7 +78,7 @@ pub async fn run(
...
@@ -78,7 +78,7 @@ pub async fn run(
let
cmpl_pipeline
=
common
::
build_pipeline
::
<
let
cmpl_pipeline
=
common
::
build_pipeline
::
<
NvCreateCompletionRequest
,
NvCreateCompletionRequest
,
CompletionResponse
,
NvCreate
CompletionResponse
,
>
(
model
.card
(),
inner_engine
)
>
(
model
.card
(),
inner_engine
)
.await
?
;
.await
?
;
manager
.add_completions_model
(
model
.service_name
(),
cmpl_pipeline
)
?
;
manager
.add_completions_model
(
model
.service_name
(),
cmpl_pipeline
)
?
;
...
...
lib/engines/mistralrs/src/lib.rs
View file @
e3f1bd5d
...
@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
...
@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
use
dynamo_llm
::
protocols
::
openai
::{
use
dynamo_llm
::
protocols
::
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
prompt_to_string
,
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
completions
::{
prompt_to_string
,
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
};
};
...
@@ -467,13 +467,17 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
...
@@ -467,13 +467,17 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
}
}
#[async_trait]
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
impl
for
MistralRsEngine
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateCompletionResponse
>>
,
Error
,
>
for
MistralRsEngine
{
{
async
fn
generate
(
async
fn
generate
(
&
self
,
&
self
,
request
:
SingleIn
<
NvCreateCompletionRequest
>
,
request
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
let
(
request
,
context
)
=
request
.transfer
(());
let
(
request
,
context
)
=
request
.transfer
(());
let
ctx
=
context
.context
();
let
ctx
=
context
.context
();
let
(
tx
,
mut
rx
)
=
channel
(
10_000
);
let
(
tx
,
mut
rx
)
=
channel
(
10_000
);
...
...
lib/llm/src/discovery/watcher.rs
View file @
e3f1bd5d
...
@@ -25,7 +25,7 @@ use crate::{
...
@@ -25,7 +25,7 @@ use crate::{
protocols
::
openai
::
chat_completions
::{
protocols
::
openai
::
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
,
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
,
},
},
protocols
::
openai
::
completions
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
protocols
::
openai
::
completions
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
protocols
::
openai
::
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
protocols
::
openai
::
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
},
};
};
...
@@ -240,7 +240,7 @@ impl ModelWatcher {
...
@@ -240,7 +240,7 @@ impl ModelWatcher {
let
frontend
=
SegmentSource
::
<
let
frontend
=
SegmentSource
::
<
SingleIn
<
NvCreateCompletionRequest
>
,
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
>
::
new
();
>
::
new
();
let
preprocessor
=
OpenAIPreprocessor
::
new
(
card
.clone
())
.await
?
.into_operator
();
let
preprocessor
=
OpenAIPreprocessor
::
new
(
card
.clone
())
.await
?
.into_operator
();
let
backend
=
Backend
::
from_mdc
(
card
.clone
())
.await
?
.into_operator
();
let
backend
=
Backend
::
from_mdc
(
card
.clone
())
.await
?
.into_operator
();
...
@@ -292,7 +292,7 @@ impl ModelWatcher {
...
@@ -292,7 +292,7 @@ impl ModelWatcher {
ModelType
::
Completion
=>
{
ModelType
::
Completion
=>
{
let
push_router
=
PushRouter
::
<
let
push_router
=
PushRouter
::
<
NvCreateCompletionRequest
,
NvCreateCompletionRequest
,
Annotated
<
CompletionResponse
>
,
Annotated
<
NvCreate
CompletionResponse
>
,
>
::
from_client
(
client
,
Default
::
default
())
>
::
from_client
(
client
,
Default
::
default
())
.await
?
;
.await
?
;
let
engine
=
Arc
::
new
(
push_router
);
let
engine
=
Arc
::
new
(
push_router
);
...
...
lib/llm/src/engines.rs
View file @
e3f1bd5d
...
@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest;
...
@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest;
use
crate
::
protocols
::
common
::
llm_backend
::
LLMEngineOutput
;
use
crate
::
protocols
::
common
::
llm_backend
::
LLMEngineOutput
;
use
crate
::
protocols
::
openai
::{
use
crate
::
protocols
::
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
prompt_to_string
,
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
completions
::{
prompt_to_string
,
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
};
};
use
crate
::
types
::
openai
::
embeddings
::
NvCreateEmbeddingRequest
;
use
crate
::
types
::
openai
::
embeddings
::
NvCreateEmbeddingRequest
;
use
crate
::
types
::
openai
::
embeddings
::
NvCreateEmbeddingResponse
;
use
crate
::
types
::
openai
::
embeddings
::
NvCreateEmbeddingResponse
;
...
@@ -142,7 +142,7 @@ pub trait StreamingEngine: Send + Sync {
...
@@ -142,7 +142,7 @@ pub trait StreamingEngine: Send + Sync {
async
fn
handle_completion
(
async
fn
handle_completion
(
&
self
,
&
self
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
;
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
;
async
fn
handle_chat
(
async
fn
handle_chat
(
&
self
,
&
self
,
...
@@ -219,13 +219,17 @@ impl
...
@@ -219,13 +219,17 @@ impl
}
}
#[async_trait]
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
impl
for
EchoEngineFull
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateCompletionResponse
>>
,
Error
,
>
for
EchoEngineFull
{
{
async
fn
generate
(
async
fn
generate
(
&
self
,
&
self
,
incoming_request
:
SingleIn
<
NvCreateCompletionRequest
>
,
incoming_request
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
let
(
request
,
context
)
=
incoming_request
.transfer
(());
let
(
request
,
context
)
=
incoming_request
.transfer
(());
let
deltas
=
request
.response_generator
();
let
deltas
=
request
.response_generator
();
let
ctx
=
context
.context
();
let
ctx
=
context
.context
();
...
@@ -268,7 +272,7 @@ impl<E> StreamingEngine for EngineDispatcher<E>
...
@@ -268,7 +272,7 @@ impl<E> StreamingEngine for EngineDispatcher<E>
where
where
E
:
AsyncEngine
<
E
:
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
,
Error
,
>
+
AsyncEngine
<
>
+
AsyncEngine
<
SingleIn
<
NvCreateChatCompletionRequest
>
,
SingleIn
<
NvCreateChatCompletionRequest
>
,
...
@@ -284,7 +288,7 @@ where
...
@@ -284,7 +288,7 @@ where
async
fn
handle_completion
(
async
fn
handle_completion
(
&
self
,
&
self
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
self
.inner
.generate
(
req
)
.await
self
.inner
.generate
(
req
)
.await
}
}
...
@@ -347,13 +351,17 @@ impl StreamingEngineAdapter {
...
@@ -347,13 +351,17 @@ impl StreamingEngineAdapter {
}
}
#[async_trait]
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
impl
for
StreamingEngineAdapter
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateCompletionResponse
>>
,
Error
,
>
for
StreamingEngineAdapter
{
{
async
fn
generate
(
async
fn
generate
(
&
self
,
&
self
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
req
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
self
.0
.handle_completion
(
req
)
.await
self
.0
.handle_completion
(
req
)
.await
}
}
}
}
...
...
lib/llm/src/http/service/openai.rs
View file @
e3f1bd5d
...
@@ -30,7 +30,7 @@ use super::{
...
@@ -30,7 +30,7 @@ use super::{
use
crate
::
preprocessor
::
LLMMetricAnnotation
;
use
crate
::
preprocessor
::
LLMMetricAnnotation
;
use
crate
::
protocols
::
openai
::
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
};
use
crate
::
protocols
::
openai
::
embeddings
::{
NvCreateEmbeddingRequest
,
NvCreateEmbeddingResponse
};
use
crate
::
protocols
::
openai
::{
use
crate
::
protocols
::
openai
::{
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
CompletionResponse
,
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
NvCreate
CompletionResponse
,
};
};
use
crate
::
request_template
::
RequestTemplate
;
use
crate
::
request_template
::
RequestTemplate
;
use
crate
::
types
::{
use
crate
::
types
::{
...
@@ -193,7 +193,7 @@ async fn completions(
...
@@ -193,7 +193,7 @@ async fn completions(
Ok
(
sse_stream
.into_response
())
Ok
(
sse_stream
.into_response
())
}
else
{
}
else
{
// TODO: report ISL/OSL for non-streaming requests
// TODO: report ISL/OSL for non-streaming requests
let
response
=
CompletionResponse
::
from_annotated_stream
(
stream
.into
())
let
response
=
NvCreate
CompletionResponse
::
from_annotated_stream
(
stream
.into
())
.await
.await
.map_err
(|
e
|
{
.map_err
(|
e
|
{
tracing
::
error!
(
tracing
::
error!
(
...
...
lib/llm/src/preprocessor.rs
View file @
e3f1bd5d
...
@@ -46,7 +46,7 @@ use crate::protocols::{
...
@@ -46,7 +46,7 @@ use crate::protocols::{
common
::{
SamplingOptionsProvider
,
StopConditionsProvider
},
common
::{
SamplingOptionsProvider
,
StopConditionsProvider
},
openai
::{
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
completions
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
nvext
::
NvExtProvider
,
nvext
::
NvExtProvider
,
DeltaGeneratorExt
,
DeltaGeneratorExt
,
},
},
...
@@ -433,7 +433,7 @@ impl
...
@@ -433,7 +433,7 @@ impl
impl
impl
Operator
<
Operator
<
SingleIn
<
NvCreateCompletionRequest
>
,
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
SingleIn
<
PreprocessedRequest
>
,
SingleIn
<
PreprocessedRequest
>
,
ManyOut
<
Annotated
<
BackendOutput
>>
,
ManyOut
<
Annotated
<
BackendOutput
>>
,
>
for
OpenAIPreprocessor
>
for
OpenAIPreprocessor
...
@@ -448,7 +448,7 @@ impl
...
@@ -448,7 +448,7 @@ impl
Error
,
Error
,
>
,
>
,
>
,
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
// unpack the request
// unpack the request
let
(
request
,
context
)
=
request
.into_parts
();
let
(
request
,
context
)
=
request
.into_parts
();
...
@@ -465,7 +465,7 @@ impl
...
@@ -465,7 +465,7 @@ impl
let
common_request
=
context
.map
(|
_
|
common_request
);
let
common_request
=
context
.map
(|
_
|
common_request
);
// create a stream of annotations this will be prepend to the response stream
// create a stream of annotations this will be prepend to the response stream
let
annotations
:
Vec
<
Annotated
<
CompletionResponse
>>
=
annotations
let
annotations
:
Vec
<
Annotated
<
NvCreate
CompletionResponse
>>
=
annotations
.into_iter
()
.into_iter
()
.flat_map
(|(
k
,
v
)|
Annotated
::
from_annotation
(
k
,
&
v
))
.flat_map
(|(
k
,
v
)|
Annotated
::
from_annotation
(
k
,
&
v
))
.collect
();
.collect
();
...
...
lib/llm/src/protocols/openai/completions.rs
View file @
e3f1bd5d
...
@@ -39,41 +39,10 @@ pub struct NvCreateCompletionRequest {
...
@@ -39,41 +39,10 @@ pub struct NvCreateCompletionRequest {
pub
nvext
:
Option
<
NvExt
>
,
pub
nvext
:
Option
<
NvExt
>
,
}
}
/// Legacy OpenAI CompletionResponse
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
/// Represents a completion response from the API.
pub
struct
NvCreateCompletionResponse
{
/// Note: both the streamed and non-streamed response objects share the same
#[serde(flatten)]
/// shape (unlike the chat endpoint).
pub
inner
:
async_openai
::
types
::
CreateCompletionResponse
,
#[derive(Clone,
Debug,
Deserialize,
Serialize)]
pub
struct
CompletionResponse
{
/// A unique identifier for the completion.
pub
id
:
String
,
/// The list of completion choices the model generated for the input prompt.
pub
choices
:
Vec
<
async_openai
::
types
::
Choice
>
,
/// The Unix timestamp (in seconds) of when the completion was created.
pub
created
:
u64
,
/// The model used for completion.
pub
model
:
String
,
/// The object type, which is always "text_completion"
pub
object
:
String
,
/// Usage statistics for the completion request.
pub
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
/// This fingerprint represents the backend configuration that the model runs with.
/// Can be used in conjunction with the seed request parameter to understand when backend
/// changes have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
system_fingerprint
:
Option
<
String
>
,
// TODO(ryan)
// pub nvext: Option<NimResponseExt>,
}
}
impl
ContentProvider
for
async_openai
::
types
::
Choice
{
impl
ContentProvider
for
async_openai
::
types
::
Choice
{
...
@@ -205,16 +174,17 @@ impl ResponseFactory {
...
@@ -205,16 +174,17 @@ impl ResponseFactory {
&
self
,
&
self
,
choice
:
async_openai
::
types
::
Choice
,
choice
:
async_openai
::
types
::
Choice
,
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
)
->
CompletionResponse
{
)
->
NvCreate
CompletionResponse
{
CompletionResponse
{
let
inner
=
async_openai
::
types
::
Create
CompletionResponse
{
id
:
self
.id
.clone
(),
id
:
self
.id
.clone
(),
object
:
self
.object
.clone
(),
object
:
self
.object
.clone
(),
created
:
self
.created
,
created
:
self
.created
as
u32
,
model
:
self
.model
.clone
(),
model
:
self
.model
.clone
(),
choices
:
vec!
[
choice
],
choices
:
vec!
[
choice
],
system_fingerprint
:
self
.system_fingerprint
.clone
(),
system_fingerprint
:
self
.system_fingerprint
.clone
(),
usage
,
usage
,
}
};
NvCreateCompletionResponse
{
inner
}
}
}
}
}
...
...
lib/llm/src/protocols/openai/completions/aggregator.rs
View file @
e3f1bd5d
...
@@ -18,7 +18,7 @@ use std::collections::HashMap;
...
@@ -18,7 +18,7 @@ use std::collections::HashMap;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
futures
::
StreamExt
;
use
super
::
CompletionResponse
;
use
super
::
NvCreate
CompletionResponse
;
use
crate
::
protocols
::{
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
codec
::{
Message
,
SseCodecError
},
common
::
FinishReason
,
common
::
FinishReason
,
...
@@ -64,8 +64,8 @@ impl DeltaAggregator {
...
@@ -64,8 +64,8 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub
async
fn
apply
(
pub
async
fn
apply
(
stream
:
DataStream
<
Annotated
<
CompletionResponse
>>
,
stream
:
DataStream
<
Annotated
<
NvCreate
CompletionResponse
>>
,
)
->
Result
<
CompletionResponse
>
{
)
->
Result
<
NvCreate
CompletionResponse
>
{
let
aggregator
=
stream
let
aggregator
=
stream
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
.fold
(
DeltaAggregator
::
new
(),
|
mut
aggregator
,
delta
|
async
move
{
let
delta
=
match
delta
.ok
()
{
let
delta
=
match
delta
.ok
()
{
...
@@ -83,18 +83,18 @@ impl DeltaAggregator {
...
@@ -83,18 +83,18 @@ impl DeltaAggregator {
// these are cheap to move so we do it every time since we are consuming the delta
// these are cheap to move so we do it every time since we are consuming the delta
let
delta
=
delta
.data
.unwrap
();
let
delta
=
delta
.data
.unwrap
();
aggregator
.id
=
delta
.id
;
aggregator
.id
=
delta
.
inner.
id
;
aggregator
.model
=
delta
.model
;
aggregator
.model
=
delta
.
inner.
model
;
aggregator
.created
=
delta
.created
;
aggregator
.created
=
delta
.
inner.
created
as
u64
;
if
let
Some
(
usage
)
=
delta
.usage
{
if
let
Some
(
usage
)
=
delta
.
inner.
usage
{
aggregator
.usage
=
Some
(
usage
);
aggregator
.usage
=
Some
(
usage
);
}
}
if
let
Some
(
system_fingerprint
)
=
delta
.system_fingerprint
{
if
let
Some
(
system_fingerprint
)
=
delta
.
inner.
system_fingerprint
{
aggregator
.system_fingerprint
=
Some
(
system_fingerprint
);
aggregator
.system_fingerprint
=
Some
(
system_fingerprint
);
}
}
// handle the choices
// handle the choices
for
choice
in
delta
.choices
{
for
choice
in
delta
.
inner.
choices
{
let
state_choice
=
let
state_choice
=
aggregator
aggregator
.choices
.choices
...
@@ -145,15 +145,19 @@ impl DeltaAggregator {
...
@@ -145,15 +145,19 @@ impl DeltaAggregator {
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
Ok
(
CompletionResponse
{
let
inner
=
async_openai
::
types
::
Create
CompletionResponse
{
id
:
aggregator
.id
,
id
:
aggregator
.id
,
created
:
aggregator
.created
,
created
:
aggregator
.created
as
u32
,
usage
:
aggregator
.usage
,
usage
:
aggregator
.usage
,
model
:
aggregator
.model
,
model
:
aggregator
.model
,
object
:
"text_completion"
.to_string
(),
object
:
"text_completion"
.to_string
(),
system_fingerprint
:
aggregator
.system_fingerprint
,
system_fingerprint
:
aggregator
.system_fingerprint
,
choices
,
choices
,
})
};
let
response
=
NvCreateCompletionResponse
{
inner
};
Ok
(
response
)
}
}
}
}
...
@@ -170,17 +174,17 @@ impl From<DeltaChoice> for async_openai::types::Choice {
...
@@ -170,17 +174,17 @@ impl From<DeltaChoice> for async_openai::types::Choice {
}
}
}
}
impl
CompletionResponse
{
impl
NvCreate
CompletionResponse
{
pub
async
fn
from_sse_stream
(
pub
async
fn
from_sse_stream
(
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
stream
:
DataStream
<
Result
<
Message
,
SseCodecError
>>
,
)
->
Result
<
CompletionResponse
>
{
)
->
Result
<
NvCreate
CompletionResponse
>
{
let
stream
=
convert_sse_stream
::
<
CompletionResponse
>
(
stream
);
let
stream
=
convert_sse_stream
::
<
NvCreate
CompletionResponse
>
(
stream
);
CompletionResponse
::
from_annotated_stream
(
stream
)
.await
NvCreate
CompletionResponse
::
from_annotated_stream
(
stream
)
.await
}
}
pub
async
fn
from_annotated_stream
(
pub
async
fn
from_annotated_stream
(
stream
:
DataStream
<
Annotated
<
CompletionResponse
>>
,
stream
:
DataStream
<
Annotated
<
NvCreate
CompletionResponse
>>
,
)
->
Result
<
CompletionResponse
>
{
)
->
Result
<
NvCreate
CompletionResponse
>
{
DeltaAggregator
::
apply
(
stream
)
.await
DeltaAggregator
::
apply
(
stream
)
.await
}
}
}
}
...
@@ -192,13 +196,13 @@ mod tests {
...
@@ -192,13 +196,13 @@ mod tests {
use
futures
::
stream
;
use
futures
::
stream
;
use
super
::
*
;
use
super
::
*
;
use
crate
::
protocols
::
openai
::
completions
::
CompletionResponse
;
use
crate
::
protocols
::
openai
::
completions
::
NvCreate
CompletionResponse
;
fn
create_test_delta
(
fn
create_test_delta
(
index
:
u64
,
index
:
u64
,
text
:
&
str
,
text
:
&
str
,
finish_reason
:
Option
<
String
>
,
finish_reason
:
Option
<
String
>
,
)
->
Annotated
<
CompletionResponse
>
{
)
->
Annotated
<
NvCreate
CompletionResponse
>
{
// This will silently discard invalid_finish reason values and fall back
// This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code
// to None - totally fine since this is test code
let
finish_reason
=
finish_reason
let
finish_reason
=
finish_reason
...
@@ -206,21 +210,25 @@ mod tests {
...
@@ -206,21 +210,25 @@ mod tests {
.and_then
(|
s
|
FinishReason
::
from_str
(
s
)
.ok
())
.and_then
(|
s
|
FinishReason
::
from_str
(
s
)
.ok
())
.map
(
Into
::
into
);
.map
(
Into
::
into
);
let
inner
=
async_openai
::
types
::
CreateCompletionResponse
{
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b"
.to_string
(),
created
:
1234567890
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
async_openai
::
types
::
Choice
{
index
:
index
as
u32
,
text
:
text
.to_string
(),
finish_reason
,
logprobs
:
None
,
}],
object
:
"text_completion"
.to_string
(),
};
let
response
=
NvCreateCompletionResponse
{
inner
};
Annotated
{
Annotated
{
data
:
Some
(
CompletionResponse
{
data
:
Some
(
response
),
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b"
.to_string
(),
created
:
1234567890
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
async_openai
::
types
::
Choice
{
index
:
index
as
u32
,
text
:
text
.to_string
(),
finish_reason
,
logprobs
:
None
,
}],
object
:
"text_completion"
.to_string
(),
}),
id
:
Some
(
"test_id"
.to_string
()),
id
:
Some
(
"test_id"
.to_string
()),
event
:
None
,
event
:
None
,
comment
:
None
,
comment
:
None
,
...
@@ -230,7 +238,7 @@ mod tests {
...
@@ -230,7 +238,7 @@ mod tests {
#[tokio::test]
#[tokio::test]
async
fn
test_empty_stream
()
{
async
fn
test_empty_stream
()
{
// Create an empty stream
// Create an empty stream
let
stream
:
DataStream
<
Annotated
<
CompletionResponse
>>
=
Box
::
pin
(
stream
::
empty
());
let
stream
:
DataStream
<
Annotated
<
NvCreate
CompletionResponse
>>
=
Box
::
pin
(
stream
::
empty
());
// Call DeltaAggregator::apply
// Call DeltaAggregator::apply
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
let
result
=
DeltaAggregator
::
apply
(
stream
)
.await
;
...
@@ -240,12 +248,12 @@ mod tests {
...
@@ -240,12 +248,12 @@ mod tests {
let
response
=
result
.unwrap
();
let
response
=
result
.unwrap
();
// Verify that the response is empty and has default values
// Verify that the response is empty and has default values
assert_eq!
(
response
.id
,
""
);
assert_eq!
(
response
.
inner.
id
,
""
);
assert_eq!
(
response
.model
,
""
);
assert_eq!
(
response
.
inner.
model
,
""
);
assert_eq!
(
response
.created
,
0
);
assert_eq!
(
response
.
inner.
created
,
0
);
assert
!
(
response
.usage
.is_none
());
assert
!
(
response
.
inner.
usage
.is_none
());
assert
!
(
response
.system_fingerprint
.is_none
());
assert
!
(
response
.
inner.
system_fingerprint
.is_none
());
assert_eq!
(
response
.choices
.len
(),
0
);
assert_eq!
(
response
.
inner.
choices
.len
(),
0
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -264,19 +272,23 @@ mod tests {
...
@@ -264,19 +272,23 @@ mod tests {
let
response
=
result
.unwrap
();
let
response
=
result
.unwrap
();
// Verify the response fields
// Verify the response fields
assert_eq!
(
response
.id
,
"test_id"
);
assert_eq!
(
response
.
inner.
id
,
"test_id"
);
assert_eq!
(
response
.model
,
"meta/llama-3.1-8b"
);
assert_eq!
(
response
.
inner.
model
,
"meta/llama-3.1-8b"
);
assert_eq!
(
response
.created
,
1234567890
);
assert_eq!
(
response
.
inner.
created
,
1234567890
);
assert
!
(
response
.usage
.is_none
());
assert
!
(
response
.
inner.
usage
.is_none
());
assert
!
(
response
.system_fingerprint
.is_none
());
assert
!
(
response
.
inner.
system_fingerprint
.is_none
());
assert_eq!
(
response
.choices
.len
(),
1
);
assert_eq!
(
response
.
inner.
choices
.len
(),
1
);
let
choice
=
&
response
.choices
[
0
];
let
choice
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.text
,
"Hello,"
.to_string
());
assert_eq!
(
choice
.text
,
"Hello,"
.to_string
());
assert_eq!
(
assert_eq!
(
choice
.finish_reason
,
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Length
)
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Length
)
);
);
assert_eq!
(
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Length
)
);
assert
!
(
choice
.logprobs
.is_none
());
assert
!
(
choice
.logprobs
.is_none
());
}
}
...
@@ -300,42 +312,50 @@ mod tests {
...
@@ -300,42 +312,50 @@ mod tests {
let
response
=
result
.unwrap
();
let
response
=
result
.unwrap
();
// Verify the response fields
// Verify the response fields
assert_eq!
(
response
.choices
.len
(),
1
);
assert_eq!
(
response
.
inner.
choices
.len
(),
1
);
let
choice
=
&
response
.choices
[
0
];
let
choice
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.text
,
"Hello, world!"
.to_string
());
assert_eq!
(
choice
.text
,
"Hello, world!"
.to_string
());
assert_eq!
(
assert_eq!
(
choice
.finish_reason
,
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
);
assert_eq!
(
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
}
}
#[tokio::test]
#[tokio::test]
async
fn
test_multiple_choices
()
{
async
fn
test_multiple_choices
()
{
// Create a delta with multiple choices
// Create a delta with multiple choices
let
inner
=
async_openai
::
types
::
CreateCompletionResponse
{
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b"
.to_string
(),
created
:
1234567890
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
async_openai
::
types
::
Choice
{
index
:
0
,
text
:
"Choice 0"
.to_string
(),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
async_openai
::
types
::
Choice
{
index
:
1
,
text
:
"Choice 1"
.to_string
(),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
],
object
:
"text_completion"
.to_string
(),
};
let
response
=
NvCreateCompletionResponse
{
inner
};
let
annotated_delta
=
Annotated
{
let
annotated_delta
=
Annotated
{
data
:
Some
(
CompletionResponse
{
data
:
Some
(
response
),
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b"
.to_string
(),
created
:
1234567890
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
async_openai
::
types
::
Choice
{
index
:
0
,
text
:
"Choice 0"
.to_string
(),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
async_openai
::
types
::
Choice
{
index
:
1
,
text
:
"Choice 1"
.to_string
(),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
],
object
:
"text_completion"
.to_string
(),
}),
id
:
Some
(
"test_id"
.to_string
()),
id
:
Some
(
"test_id"
.to_string
()),
event
:
None
,
event
:
None
,
comment
:
None
,
comment
:
None
,
...
@@ -352,22 +372,30 @@ mod tests {
...
@@ -352,22 +372,30 @@ mod tests {
let
mut
response
=
result
.unwrap
();
let
mut
response
=
result
.unwrap
();
// Verify the response fields
// Verify the response fields
assert_eq!
(
response
.choices
.len
(),
2
);
assert_eq!
(
response
.
inner.
choices
.len
(),
2
);
response
.choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
// Ensure the choices are ordered
response
.
inner.
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
// Ensure the choices are ordered
let
choice0
=
&
response
.choices
[
0
];
let
choice0
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice0
.index
,
0
);
assert_eq!
(
choice0
.index
,
0
);
assert_eq!
(
choice0
.text
,
"Choice 0"
.to_string
());
assert_eq!
(
choice0
.text
,
"Choice 0"
.to_string
());
assert_eq!
(
assert_eq!
(
choice0
.finish_reason
,
choice0
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
);
assert_eq!
(
choice0
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
let
choice1
=
&
response
.choices
[
1
];
let
choice1
=
&
response
.
inner.
choices
[
1
];
assert_eq!
(
choice1
.index
,
1
);
assert_eq!
(
choice1
.index
,
1
);
assert_eq!
(
choice1
.text
,
"Choice 1"
.to_string
());
assert_eq!
(
choice1
.text
,
"Choice 1"
.to_string
());
assert_eq!
(
assert_eq!
(
choice1
.finish_reason
,
choice1
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
);
assert_eq!
(
choice1
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
}
}
}
}
lib/llm/src/protocols/openai/completions/delta.rs
View file @
e3f1bd5d
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
super
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
};
use
super
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
};
use
crate
::
protocols
::
common
;
use
crate
::
protocols
::
common
;
impl
NvCreateCompletionRequest
{
impl
NvCreateCompletionRequest
{
...
@@ -83,7 +83,7 @@ impl DeltaGenerator {
...
@@ -83,7 +83,7 @@ impl DeltaGenerator {
index
:
u64
,
index
:
u64
,
text
:
Option
<
String
>
,
text
:
Option
<
String
>
,
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
,
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
,
)
->
CompletionResponse
{
)
->
NvCreate
CompletionResponse
{
// todo - update for tool calling
// todo - update for tool calling
let
mut
usage
=
self
.usage
.clone
();
let
mut
usage
=
self
.usage
.clone
();
...
@@ -91,10 +91,10 @@ impl DeltaGenerator {
...
@@ -91,10 +91,10 @@ impl DeltaGenerator {
usage
.total_tokens
=
usage
.prompt_tokens
+
usage
.completion_tokens
;
usage
.total_tokens
=
usage
.prompt_tokens
+
usage
.completion_tokens
;
}
}
CompletionResponse
{
let
inner
=
async_openai
::
types
::
Create
CompletionResponse
{
id
:
self
.id
.clone
(),
id
:
self
.id
.clone
(),
object
:
self
.object
.clone
(),
object
:
self
.object
.clone
(),
created
:
self
.created
,
created
:
self
.created
as
u32
,
model
:
self
.model
.clone
(),
model
:
self
.model
.clone
(),
system_fingerprint
:
self
.system_fingerprint
.clone
(),
system_fingerprint
:
self
.system_fingerprint
.clone
(),
choices
:
vec!
[
async_openai
::
types
::
Choice
{
choices
:
vec!
[
async_openai
::
types
::
Choice
{
...
@@ -108,15 +108,17 @@ impl DeltaGenerator {
...
@@ -108,15 +108,17 @@ impl DeltaGenerator {
}
else
{
}
else
{
None
None
},
},
}
};
NvCreateCompletionResponse
{
inner
}
}
}
}
}
impl
crate
::
protocols
::
openai
::
DeltaGeneratorExt
<
CompletionResponse
>
for
DeltaGenerator
{
impl
crate
::
protocols
::
openai
::
DeltaGeneratorExt
<
NvCreate
CompletionResponse
>
for
DeltaGenerator
{
fn
choice_from_postprocessor
(
fn
choice_from_postprocessor
(
&
mut
self
,
&
mut
self
,
delta
:
common
::
llm_backend
::
BackendOutput
,
delta
:
common
::
llm_backend
::
BackendOutput
,
)
->
anyhow
::
Result
<
CompletionResponse
>
{
)
->
anyhow
::
Result
<
NvCreate
CompletionResponse
>
{
// aggregate usage
// aggregate usage
if
self
.options.enable_usage
{
if
self
.options.enable_usage
{
self
.usage.completion_tokens
+=
delta
.token_ids
.len
()
as
u32
;
self
.usage.completion_tokens
+=
delta
.token_ids
.len
()
as
u32
;
...
...
lib/llm/src/types.rs
View file @
e3f1bd5d
...
@@ -24,15 +24,17 @@ pub mod openai {
...
@@ -24,15 +24,17 @@ pub mod openai {
pub
mod
completions
{
pub
mod
completions
{
use
super
::
*
;
use
super
::
*
;
pub
use
protocols
::
openai
::
completions
::{
CompletionResponse
,
NvCreateCompletionRequest
};
pub
use
protocols
::
openai
::
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
,
};
/// A [`UnaryEngine`] implementation for the OpenAI Completions API
/// A [`UnaryEngine`] implementation for the OpenAI Completions API
pub
type
OpenAICompletionsUnaryEngine
=
pub
type
OpenAICompletionsUnaryEngine
=
UnaryEngine
<
NvCreateCompletionRequest
,
CompletionResponse
>
;
UnaryEngine
<
NvCreateCompletionRequest
,
NvCreate
CompletionResponse
>
;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API
/// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API
pub
type
OpenAICompletionsStreamingEngine
=
pub
type
OpenAICompletionsStreamingEngine
=
ServerStreamingEngine
<
NvCreateCompletionRequest
,
Annotated
<
CompletionResponse
>>
;
ServerStreamingEngine
<
NvCreateCompletionRequest
,
Annotated
<
NvCreate
CompletionResponse
>>
;
}
}
pub
mod
chat_completions
{
pub
mod
chat_completions
{
...
...
lib/llm/tests/aggregators.rs
View file @
e3f1bd5d
...
@@ -15,7 +15,9 @@
...
@@ -15,7 +15,9 @@
use
dynamo_llm
::
protocols
::{
use
dynamo_llm
::
protocols
::{
codec
::{
create_message_stream
,
Message
,
SseCodecError
},
codec
::{
create_message_stream
,
Message
,
SseCodecError
},
openai
::{
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
CompletionResponse
},
openai
::{
chat_completions
::
NvCreateChatCompletionResponse
,
completions
::
NvCreateCompletionResponse
,
},
ContentProvider
,
DataStream
,
ContentProvider
,
DataStream
,
};
};
use
futures
::
StreamExt
;
use
futures
::
StreamExt
;
...
@@ -112,13 +114,13 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
...
@@ -112,13 +114,13 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test]
#[tokio::test]
async
fn
test_openai_cmpl_stream
()
{
async
fn
test_openai_cmpl_stream
()
{
let
stream
=
create_stream
(
CMPL_ROOT_PATH
,
"completion.streaming.1"
)
.take
(
16
);
let
stream
=
create_stream
(
CMPL_ROOT_PATH
,
"completion.streaming.1"
)
.take
(
16
);
let
result
=
CompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
let
result
=
NvCreate
CompletionResponse
::
from_sse_stream
(
Box
::
pin
(
stream
))
.await
.await
.unwrap
();
.unwrap
();
// todo: provide a cleaner way to extract the content from choices
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
assert_eq!
(
result
.choices
.first
()
.unwrap
()
.content
(),
result
.
inner.
choices
.first
()
.unwrap
()
.content
(),
" This is a question that is often asked by those outside of AI research and development"
" This is a question that is often asked by those outside of AI research and development"
);
);
}
}
lib/llm/tests/http-service.rs
View file @
e3f1bd5d
...
@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{
...
@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{
use
dynamo_llm
::
protocols
::{
use
dynamo_llm
::
protocols
::{
openai
::{
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
CompletionRe
sponse
,
NvCreateCompletionRe
quest
},
completions
::{
NvCreate
CompletionRe
quest
,
NvCreateCompletionRe
sponse
},
},
},
Annotated
,
Annotated
,
};
};
...
@@ -101,13 +101,17 @@ impl
...
@@ -101,13 +101,17 @@ impl
}
}
#[async_trait]
#[async_trait]
impl
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
impl
for
AlwaysFailEngine
AsyncEngine
<
SingleIn
<
NvCreateCompletionRequest
>
,
ManyOut
<
Annotated
<
NvCreateCompletionResponse
>>
,
Error
,
>
for
AlwaysFailEngine
{
{
async
fn
generate
(
async
fn
generate
(
&
self
,
&
self
,
_
request
:
SingleIn
<
NvCreateCompletionRequest
>
,
_
request
:
SingleIn
<
NvCreateCompletionRequest
>
,
)
->
Result
<
ManyOut
<
Annotated
<
CompletionResponse
>>
,
Error
>
{
)
->
Result
<
ManyOut
<
Annotated
<
NvCreate
CompletionResponse
>>
,
Error
>
{
Err
(
HttpError
{
Err
(
HttpError
{
code
:
401
,
code
:
401
,
message
:
"Always fail"
.to_string
(),
message
:
"Always fail"
.to_string
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment