Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
86aff237
Commit
86aff237
authored
Feb 26, 2025
by
Paul Hendricks
Committed by
GitHub
Feb 26, 2025
Browse files
refactor: using async_openai
Co-authored-by:
Graham King
<
grahamk@nvidia.com
>
parent
d694ca6e
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
102 additions
and
32 deletions
+102
-32
lib/llm/src/protocols/openai/completions/aggregator.rs
lib/llm/src/protocols/openai/completions/aggregator.rs
+6
-9
lib/llm/src/types.rs
lib/llm/src/types.rs
+4
-0
lib/llm/tests/aggregators.rs
lib/llm/tests/aggregators.rs
+34
-3
lib/llm/tests/http-service.rs
lib/llm/tests/http-service.rs
+35
-6
lib/llm/tests/openai_completions.rs
lib/llm/tests/openai_completions.rs
+1
-1
lib/llm/tests/preprocessor.rs
lib/llm/tests/preprocessor.rs
+22
-13
No files found.
lib/llm/src/protocols/openai/completions/aggregator.rs
View file @
86aff237
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -21,6 +21,7 @@ use futures::StreamExt;
use
super
::{
CompletionChoice
,
CompletionResponse
,
CompletionUsage
,
LogprobResult
};
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
common
::
FinishReason
,
convert_sse_stream
,
Annotated
,
DataStream
,
};
...
...
@@ -38,7 +39,7 @@ pub struct DeltaAggregator {
struct
DeltaChoice
{
index
:
u64
,
text
:
String
,
finish_reason
:
Option
<
crate
::
protocols
::
openai
::
chat_completions
::
FinishReason
>
,
finish_reason
:
Option
<
FinishReason
>
,
logprobs
:
Option
<
LogprobResult
>
,
}
...
...
@@ -110,11 +111,7 @@ impl DeltaAggregator {
// todo - handle logprobs
if
let
Some
(
finish_reason
)
=
choice
.finish_reason
{
let
reason
=
crate
::
protocols
::
openai
::
chat_completions
::
FinishReason
::
from_str
(
&
finish_reason
,
)
.ok
();
let
reason
=
FinishReason
::
from_str
(
&
finish_reason
)
.ok
();
state_choice
.finish_reason
=
reason
;
}
}
...
...
lib/llm/src/types.rs
View file @
86aff237
...
...
@@ -37,6 +37,10 @@ pub mod openai {
pub
mod
chat_completions
{
use
super
::
*
;
// pub use async_openai::types::CreateChatCompletionRequest as ChatCompletionRequest;
// pub use protocols::openai::chat_completions::{
// ChatCompletionResponse, ChatCompletionResponseDelta,
// };
pub
use
protocols
::
openai
::
chat_completions
::{
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseDelta
,
};
...
...
lib/llm/tests/aggregators.rs
View file @
86aff237
...
...
@@ -40,8 +40,17 @@ async fn test_openai_chat_stream() {
// todo: provide a cleaner way to extract the content from choices
assert_eq!
(
result
.choices
.first
()
.unwrap
()
.content
(),
result
.inner
.choices
.first
()
.unwrap
()
.message
.content
.clone
()
.expect
(
"there to be content"
),
"Deep learning is a subfield of machine learning that involves the use of artificial"
.to_string
()
);
}
...
...
@@ -52,7 +61,18 @@ async fn test_openai_chat_edge_case_multi_line_data() {
.await
.unwrap
();
assert_eq!
(
result
.choices
.first
()
.unwrap
()
.content
(),
"Deep learning"
);
assert_eq!
(
result
.inner
.choices
.first
()
.unwrap
()
.message
.content
.clone
()
.expect
(
"there to be content"
),
"Deep learning"
.to_string
()
);
}
#[tokio::test]
...
...
@@ -62,7 +82,18 @@ async fn test_openai_chat_edge_case_comments_per_response() {
.await
.unwrap
();
assert_eq!
(
result
.choices
.first
()
.unwrap
()
.content
(),
"Deep learning"
);
assert_eq!
(
result
.inner
.choices
.first
()
.unwrap
()
.message
.content
.clone
()
.expect
(
"there to be content"
),
"Deep learning"
.to_string
()
);
}
#[tokio::test]
...
...
lib/llm/tests/http-service.rs
View file @
86aff237
...
...
@@ -40,6 +40,7 @@ use triton_distributed_runtime::{
struct
CounterEngine
{}
#[allow(deprecated)]
#[async_trait]
impl
AsyncEngine
<
...
...
@@ -55,7 +56,8 @@ impl
let
(
request
,
context
)
=
request
.transfer
(());
let
ctx
=
context
.context
();
let
max_tokens
=
request
.max_tokens
.unwrap_or
(
0
)
as
u64
;
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
let
max_tokens
=
request
.inner.max_tokens
.unwrap_or
(
0
)
as
u64
;
// let generator = ChatCompletionResponseDelta::generator(request.model.clone());
let
generator
=
request
.response_generator
();
...
...
@@ -63,8 +65,13 @@ impl
let
stream
=
stream!
{
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
max_tokens
))
.await
;
for
i
in
0
..
10
{
let
choice
=
generator
.create_choice
(
i
as
u64
,
Some
(
format!
(
"choice {i}"
)),
None
,
None
);
yield
Annotated
::
from_data
(
choice
);
let
inner
=
generator
.create_choice
(
i
,
Some
(
format!
(
"choice {i}"
)),
None
,
None
);
let
output
=
ChatCompletionResponseDelta
{
inner
,
};
yield
Annotated
::
from_data
(
output
);
}
};
...
...
@@ -174,6 +181,7 @@ fn inc_counter(
expected
[
index
]
+=
1
;
}
#[allow(deprecated)]
#[tokio::test]
async
fn
test_http_service
()
{
let
service
=
HttpService
::
builder
()
.port
(
8989
)
.build
()
.unwrap
();
...
...
@@ -207,14 +215,31 @@ async fn test_http_service() {
let
client
=
reqwest
::
Client
::
new
();
let
mut
request
=
ChatCompletionRequest
::
builder
()
let
message
=
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
async_openai
::
types
::
ChatCompletionRequestUserMessage
{
content
:
async_openai
::
types
::
ChatCompletionRequestUserMessageContent
::
Text
(
"hi"
.to_string
(),
),
name
:
None
,
},
);
let
mut
request
=
async_openai
::
types
::
CreateChatCompletionRequestArgs
::
default
()
.model
(
"foo"
)
.
add_user_
message
(
"hi"
)
.
messages
(
vec!
[
message
]
)
.build
()
.unwrap
();
.expect
(
"Failed to build request"
);
// let mut request = ChatCompletionRequest::builder()
// .model("foo")
// .add_user_message("hi")
// .build()
// .unwrap();
// ==== ChatCompletions / Stream / Success ====
request
.stream
=
Some
(
true
);
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
request
.max_tokens
=
Some
(
3000
);
let
response
=
client
...
...
@@ -293,6 +318,8 @@ async fn test_http_service() {
// ==== ChatCompletions / Unary / Success ====
request
.stream
=
Some
(
false
);
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
request
.max_tokens
=
Some
(
0
);
let
future
=
client
...
...
@@ -315,6 +342,8 @@ async fn test_http_service() {
// ==== ChatCompletions / Stream / Error ====
request
.model
=
"bar"
.to_string
();
// ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
request
.max_tokens
=
Some
(
0
);
request
.stream
=
Some
(
true
);
...
...
lib/llm/tests/openai_completions.rs
View file @
86aff237
...
...
@@ -136,7 +136,7 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
samples
.push
(
CompletionSample
::
new
(
"should have prompt, model, and max_tokens fields"
,
|
builder
|
builder
.max_tokens
(
10
),
|
builder
|
builder
.max_tokens
(
10
_u32
),
)
?
);
samples
.push
(
CompletionSample
::
new
(
...
...
lib/llm/tests/preprocessor.rs
View file @
86aff237
...
...
@@ -18,9 +18,7 @@ use anyhow::Ok;
use
serde
::{
Deserialize
,
Serialize
};
use
triton_distributed_llm
::
model_card
::
model
::{
ModelDeploymentCard
,
PromptContextMixin
};
use
triton_distributed_llm
::
preprocessor
::
prompt
::
PromptFormatter
;
use
triton_distributed_llm
::
protocols
::
openai
::
chat_completions
::{
ChatCompletionMessage
,
ChatCompletionRequest
,
Tool
,
ToolChoiceType
,
};
use
triton_distributed_llm
::
protocols
::
openai
::
chat_completions
::
ChatCompletionRequest
;
use
hf_hub
::{
api
::
tokio
::
ApiBuilder
,
Cache
,
Repo
,
RepoType
};
...
...
@@ -217,29 +215,40 @@ const TOOLS: &str = r#"
]
"#
;
// Notes:
// protocols::openai::chat_completions::ChatCompletionMessage -> async_openai::types::ChatCompletionRequestMessage
// protocols::openai::chat_completions::Tool -> async_openai::types::ChatCompletionTool
// protocols::openai::chat_completions::ToolChoiceType -> async_openai::types::ChatCompletionToolChoiceOption
#[derive(Serialize,
Deserialize)]
struct
Request
{
messages
:
Vec
<
ChatCompletionMessage
>
,
tools
:
Option
<
Vec
<
Tool
>>
,
tool_choice
:
Option
<
ToolChoice
Type
>
,
messages
:
Vec
<
async_openai
::
types
::
ChatCompletion
Request
Message
>
,
tools
:
Option
<
Vec
<
async_openai
::
types
::
ChatCompletion
Tool
>>
,
tool_choice
:
Option
<
async_openai
::
types
::
ChatCompletion
ToolChoice
Option
>
,
}
impl
Request
{
fn
from
(
messages
:
&
str
,
tools
:
Option
<&
str
>
,
tool_choice
:
Option
<
ToolChoice
Type
>
,
tool_choice
:
Option
<
async_openai
::
types
::
ChatCompletion
ToolChoice
Option
>
,
model
:
String
,
)
->
ChatCompletionRequest
{
let
messages
:
Vec
<
ChatCompletionMessage
>
=
serde_json
::
from_str
(
messages
)
.unwrap
();
let
tools
:
Option
<
Vec
<
Tool
>>
=
tools
.map
(|
x
|
serde_json
::
from_str
(
x
)
.unwrap
());
ChatCompletionRequest
::
builder
()
let
messages
:
Vec
<
async_openai
::
types
::
ChatCompletionRequestMessage
>
=
serde_json
::
from_str
(
messages
)
.unwrap
();
let
tools
:
Option
<
Vec
<
async_openai
::
types
::
ChatCompletionTool
>>
=
tools
.map
(|
x
|
serde_json
::
from_str
(
x
)
.unwrap
());
let
tools
=
tools
.unwrap
();
let
tool_choice
=
tool_choice
.unwrap
();
let
inner
=
async_openai
::
types
::
CreateChatCompletionRequestArgs
::
default
()
.model
(
model
)
.messages
(
messages
)
.tools
(
tools
)
.tool_choice
(
tool_choice
)
.build
()
.unwrap
()
.unwrap
();
ChatCompletionRequest
{
inner
,
nvext
:
None
}
}
}
...
...
@@ -295,7 +304,7 @@ async fn test_single_turn_with_tools() {
let
request
=
Request
::
from
(
SINGLE_CHAT_MESSAGE
,
Some
(
TOOLS
),
Some
(
ToolChoice
Type
::
Auto
),
Some
(
async_openai
::
types
::
ChatCompletion
ToolChoice
Option
::
Auto
),
mdc
.slug
()
.to_string
(),
);
let
formatted_prompt
=
formatter
.render
(
&
request
)
.unwrap
();
...
...
@@ -402,7 +411,7 @@ async fn test_multi_turn_with_system_with_tools() {
let
request
=
Request
::
from
(
THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM
,
Some
(
TOOLS
),
Some
(
ToolChoice
Type
::
Auto
),
Some
(
async_openai
::
types
::
ChatCompletion
ToolChoice
Option
::
Auto
),
mdc
.slug
()
.to_string
(),
);
let
formatted_prompt
=
formatter
.render
(
&
request
)
.unwrap
();
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment