Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
7b7b6a6d
Unverified
Commit
7b7b6a6d
authored
Jun 26, 2025
by
Paul Hendricks
Committed by
GitHub
Jun 26, 2025
Browse files
refactor: refactored using Choice and CompletionFinishReason (#1635)
parent
c95031ed
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
123 additions
and
88 deletions
+123
-88
lib/llm/src/engines.rs
lib/llm/src/engines.rs
+1
-1
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+31
-0
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+3
-0
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+24
-44
lib/llm/src/protocols/openai/completions/aggregator.rs
lib/llm/src/protocols/openai/completions/aggregator.rs
+57
-27
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+7
-16
No files found.
lib/llm/src/engines.rs
View file @
7b7b6a6d
...
...
@@ -238,7 +238,7 @@ impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<Completi
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
id
+=
1
;
}
let
response
=
deltas
.create_choice
(
0
,
None
,
Some
(
"stop"
.to_string
()
));
let
response
=
deltas
.create_choice
(
0
,
None
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
));
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
};
...
...
lib/llm/src/protocols/common.rs
View file @
7b7b6a6d
...
...
@@ -64,6 +64,9 @@ pub enum FinishReason {
#[serde(rename
=
"cancelled"
)]
Cancelled
,
#[serde(rename
=
"content_filter"
)]
ContentFilter
,
}
impl
std
::
fmt
::
Display
for
FinishReason
{
...
...
@@ -74,6 +77,7 @@ impl std::fmt::Display for FinishReason {
FinishReason
::
Stop
=>
write!
(
f
,
"stop"
),
FinishReason
::
Error
(
msg
)
=>
write!
(
f
,
"error: {}"
,
msg
),
FinishReason
::
Cancelled
=>
write!
(
f
,
"cancelled"
),
FinishReason
::
ContentFilter
=>
write!
(
f
,
"content_filter"
),
}
}
}
...
...
@@ -93,6 +97,33 @@ impl std::str::FromStr for FinishReason {
}
}
impl
From
<
FinishReason
>
for
async_openai
::
types
::
CompletionFinishReason
{
fn
from
(
reason
:
FinishReason
)
->
Self
{
match
reason
{
FinishReason
::
EoS
|
FinishReason
::
Stop
|
FinishReason
::
Cancelled
=>
{
async_openai
::
types
::
CompletionFinishReason
::
Stop
}
FinishReason
::
ContentFilter
=>
{
async_openai
::
types
::
CompletionFinishReason
::
ContentFilter
}
FinishReason
::
Length
=>
async_openai
::
types
::
CompletionFinishReason
::
Length
,
FinishReason
::
Error
(
_
)
=>
async_openai
::
types
::
CompletionFinishReason
::
Stop
,
}
}
}
/// Maps an OpenAI completions finish reason back onto the internal [`FinishReason`].
///
/// Each of the three OpenAI variants converts to the identically named internal variant.
impl From<async_openai::types::CompletionFinishReason> for FinishReason {
    fn from(reason: async_openai::types::CompletionFinishReason) -> Self {
        use async_openai::types::CompletionFinishReason as OpenAiReason;

        match reason {
            OpenAiReason::Stop => Self::Stop,
            OpenAiReason::Length => Self::Length,
            OpenAiReason::ContentFilter => Self::ContentFilter,
        }
    }
}
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
7b7b6a6d
...
...
@@ -203,6 +203,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
Stop
),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
Length
),
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
Stop
),
Some
(
common
::
FinishReason
::
ContentFilter
)
=>
{
Some
(
async_openai
::
types
::
FinishReason
::
ContentFilter
)
}
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
err_msg
));
}
...
...
lib/llm/src/protocols/openai/completions.rs
View file @
7b7b6a6d
...
...
@@ -49,7 +49,7 @@ pub struct CompletionResponse {
pub
id
:
String
,
/// The list of completion choices the model generated for the input prompt.
pub
choices
:
Vec
<
Completion
Choice
>
,
pub
choices
:
Vec
<
async_openai
::
types
::
Choice
>
,
/// The Unix timestamp (in seconds) of when the completion was created.
pub
created
:
u64
,
...
...
@@ -76,35 +76,12 @@ pub struct CompletionResponse {
// pub nvext: Option<NimResponseExt>,
}
/// Legacy OpenAI CompletionResponse Choice component
#[derive(Clone,
Debug,
Deserialize,
Serialize,
Builder)]
pub
struct
CompletionChoice
{
#[builder(setter(into))]
pub
text
:
String
,
#[builder(default
=
"0"
)]
pub
index
:
u64
,
#[builder(default,
setter(into,
strip_option))]
pub
finish_reason
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
pub
logprobs
:
Option
<
async_openai
::
types
::
Logprobs
>
,
}
impl
ContentProvider
for
CompletionChoice
{
impl
ContentProvider
for
async_openai
::
types
::
Choice
{
fn
content
(
&
self
)
->
String
{
self
.text
.clone
()
}
}
impl
CompletionChoice
{
pub
fn
builder
()
->
CompletionChoiceBuilder
{
CompletionChoiceBuilder
::
default
()
}
}
pub
fn
prompt_to_string
(
prompt
:
&
async_openai
::
types
::
Prompt
)
->
String
{
match
prompt
{
async_openai
::
types
::
Prompt
::
String
(
s
)
=>
s
.clone
(),
...
...
@@ -226,7 +203,7 @@ impl ResponseFactory {
pub
fn
make_response
(
&
self
,
choice
:
Completion
Choice
,
choice
:
async_openai
::
types
::
Choice
,
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
)
->
CompletionResponse
{
CompletionResponse
{
...
...
@@ -294,27 +271,30 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
}
}
impl
TryFrom
<
common
::
StreamingCompletionResponse
>
for
Completion
Choice
{
impl
TryFrom
<
common
::
StreamingCompletionResponse
>
for
async_openai
::
types
::
Choice
{
type
Error
=
anyhow
::
Error
;
fn
try_from
(
response
:
common
::
StreamingCompletionResponse
)
->
Result
<
Self
,
Self
::
Error
>
{
let
choice
=
CompletionChoice
{
text
:
response
.delta
.text
.ok_or
(
anyhow
::
anyhow!
(
"No text in response"
))
?
,
index
:
response
.delta.index
.unwrap_or
(
0
)
as
u64
,
logprobs
:
None
,
finish_reason
:
match
&
response
.delta.finish_reason
{
Some
(
common
::
FinishReason
::
EoS
)
=>
Some
(
"stop"
.to_string
()),
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
"stop"
.to_string
()),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
"length"
.to_string
()),
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
"finish_reason::error = {}"
,
err_msg
));
}
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
"cancelled"
.to_string
()),
None
=>
None
,
},
let
text
=
response
.delta
.text
.ok_or
(
anyhow
::
anyhow!
(
"No text in response"
))
?
;
// Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
// so we're fairly safe knowing we won't generate that many Choices
let
index
=
response
.delta.index
.unwrap_or
(
0
)
as
u32
;
// TODO handle aggregating logprobs
let
logprobs
=
None
;
let
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
=
response
.delta.finish_reason
.map
(
Into
::
into
);
let
choice
=
async_openai
::
types
::
Choice
{
text
,
index
,
logprobs
,
finish_reason
,
};
Ok
(
choice
)
...
...
lib/llm/src/protocols/openai/completions/aggregator.rs
View file @
7b7b6a6d
...
...
@@ -13,12 +13,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
{
collections
::
HashMap
,
str
::
FromStr
}
;
use
std
::
collections
::
HashMap
;
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
super
::
{
CompletionChoice
,
CompletionResponse
}
;
use
super
::
CompletionResponse
;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
common
::
FinishReason
,
...
...
@@ -98,9 +98,9 @@ impl DeltaAggregator {
let
state_choice
=
aggregator
.choices
.entry
(
choice
.index
)
.entry
(
choice
.index
as
u64
)
.or_insert
(
DeltaChoice
{
index
:
choice
.index
,
index
:
choice
.index
as
u64
,
text
:
""
.to_string
(),
finish_reason
:
None
,
logprobs
:
choice
.logprobs
,
...
...
@@ -108,12 +108,21 @@ impl DeltaAggregator {
state_choice
.text
.push_str
(
&
choice
.text
);
// todo - handle logprobs
if
let
Some
(
finish_reason
)
=
choice
.finish_reason
{
let
reason
=
FinishReason
::
from_str
(
&
finish_reason
)
.ok
();
state_choice
.finish_reason
=
reason
;
}
// TODO - handle logprobs
// Handle CompletionFinishReason -> FinishReason conversion
state_choice
.finish_reason
=
match
choice
.finish_reason
{
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
=>
{
Some
(
FinishReason
::
Stop
)
}
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Length
)
=>
{
Some
(
FinishReason
::
Length
)
}
Some
(
async_openai
::
types
::
CompletionFinishReason
::
ContentFilter
)
=>
{
Some
(
FinishReason
::
ContentFilter
)
}
None
=>
None
,
};
}
}
aggregator
...
...
@@ -131,7 +140,7 @@ impl DeltaAggregator {
let
mut
choices
:
Vec
<
_
>
=
aggregator
.choices
.into_values
()
.map
(
Completion
Choice
::
from
)
.map
(
async_openai
::
types
::
Choice
::
from
)
.collect
();
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
...
...
@@ -148,12 +157,12 @@ impl DeltaAggregator {
}
}
impl
From
<
DeltaChoice
>
for
Completion
Choice
{
impl
From
<
DeltaChoice
>
for
async_openai
::
types
::
Choice
{
fn
from
(
delta
:
DeltaChoice
)
->
Self
{
let
finish_reason
=
delta
.finish_reason
.map
(
|
reason
|
reason
.to_string
()
);
let
finish_reason
=
delta
.finish_reason
.map
(
Into
::
into
);
Completion
Choice
{
index
:
delta
.index
,
async_openai
::
types
::
Choice
{
index
:
delta
.index
as
u32
,
text
:
delta
.text
,
finish_reason
,
logprobs
:
delta
.logprobs
,
...
...
@@ -178,16 +187,25 @@ impl CompletionResponse {
#[cfg(test)]
mod
tests
{
use
crate
::
protocols
::
openai
::
completions
::{
CompletionChoice
,
CompletionResponse
}
;
use
std
::
str
::
FromStr
;
use
super
::
*
;
use
futures
::
stream
;
use
super
::
*
;
use
crate
::
protocols
::
openai
::
completions
::
CompletionResponse
;
fn
create_test_delta
(
index
:
u64
,
text
:
&
str
,
finish_reason
:
Option
<
String
>
,
)
->
Annotated
<
CompletionResponse
>
{
// This will silently discard invalid finish_reason values and fall back
// to None - totally fine since this is test code
let
finish_reason
=
finish_reason
.as_deref
()
.and_then
(|
s
|
FinishReason
::
from_str
(
s
)
.ok
())
.map
(
Into
::
into
);
Annotated
{
data
:
Some
(
CompletionResponse
{
id
:
"test_id"
.to_string
(),
...
...
@@ -195,8 +213,8 @@ mod tests {
created
:
1234567890
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
Completion
Choice
{
index
,
choices
:
vec!
[
async_openai
::
types
::
Choice
{
index
:
index
as
u32
,
text
:
text
.to_string
(),
finish_reason
,
logprobs
:
None
,
...
...
@@ -255,7 +273,10 @@ mod tests {
let
choice
=
&
response
.choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.text
,
"Hello,"
.to_string
());
assert_eq!
(
choice
.finish_reason
,
Some
(
"length"
.to_string
()));
assert_eq!
(
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Length
)
);
assert
!
(
choice
.logprobs
.is_none
());
}
...
...
@@ -283,7 +304,10 @@ mod tests {
let
choice
=
&
response
.choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.text
,
"Hello, world!"
.to_string
());
assert_eq!
(
choice
.finish_reason
,
Some
(
"stop"
.to_string
()));
assert_eq!
(
choice
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
}
#[tokio::test]
...
...
@@ -297,16 +321,16 @@ mod tests {
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
Completion
Choice
{
async_openai
::
types
::
Choice
{
index
:
0
,
text
:
"Choice 0"
.to_string
(),
finish_reason
:
Some
(
"stop"
.to_string
()
),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
Completion
Choice
{
async_openai
::
types
::
Choice
{
index
:
1
,
text
:
"Choice 1"
.to_string
(),
finish_reason
:
Some
(
"stop"
.to_string
()
),
finish_reason
:
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
),
logprobs
:
None
,
},
],
...
...
@@ -333,11 +357,17 @@ mod tests {
let
choice0
=
&
response
.choices
[
0
];
assert_eq!
(
choice0
.index
,
0
);
assert_eq!
(
choice0
.text
,
"Choice 0"
.to_string
());
assert_eq!
(
choice0
.finish_reason
,
Some
(
"stop"
.to_string
()));
assert_eq!
(
choice0
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
let
choice1
=
&
response
.choices
[
1
];
assert_eq!
(
choice1
.index
,
1
);
assert_eq!
(
choice1
.text
,
"Choice 1"
.to_string
());
assert_eq!
(
choice1
.finish_reason
,
Some
(
"stop"
.to_string
()));
assert_eq!
(
choice1
.finish_reason
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
);
}
}
lib/llm/src/protocols/openai/completions/delta.rs
View file @
7b7b6a6d
...
...
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::{
CompletionChoice
,
CompletionResponse
,
NvCreateCompletionRequest
};
use
super
::{
CompletionResponse
,
NvCreateCompletionRequest
};
use
crate
::
protocols
::
common
;
impl
NvCreateCompletionRequest
{
...
...
@@ -82,7 +82,7 @@ impl DeltaGenerator {
&
self
,
index
:
u64
,
text
:
Option
<
String
>
,
finish_reason
:
Option
<
String
>
,
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
,
)
->
CompletionResponse
{
// todo - update for tool calling
...
...
@@ -97,9 +97,9 @@ impl DeltaGenerator {
created
:
self
.created
,
model
:
self
.model
.clone
(),
system_fingerprint
:
self
.system_fingerprint
.clone
(),
choices
:
vec!
[
Completion
Choice
{
choices
:
vec!
[
async_openai
::
types
::
Choice
{
text
:
text
.unwrap_or_default
(),
index
,
index
:
index
as
u32
,
finish_reason
,
logprobs
:
None
,
}],
...
...
@@ -122,18 +122,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
self
.usage.completion_tokens
+=
delta
.token_ids
.len
()
as
u32
;
}
// todo logprobs
let
finish_reason
=
match
delta
.finish_reason
{
Some
(
common
::
FinishReason
::
EoS
)
=>
Some
(
"stop"
.to_string
()),
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
"stop"
.to_string
()),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
"length"
.to_string
()),
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
"cancelled"
.to_string
()),
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
err_msg
));
}
None
=>
None
,
};
// TODO logprobs
let
finish_reason
=
delta
.finish_reason
.map
(
Into
::
into
);
// create choice
let
index
=
delta
.index
.unwrap_or
(
0
)
.into
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment