Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f476fd74
Unverified
Commit
f476fd74
authored
Aug 14, 2025
by
Greg Clark
Committed by
GitHub
Aug 14, 2025
Browse files
feat: logprob handling (#2426)
Signed-off-by:
Greg Clark
<
grclark@nvidia.com
>
parent
5816c082
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
276 additions
and
20 deletions
+276
-20
lib/engines/llamacpp/src/lib.rs
lib/engines/llamacpp/src/lib.rs
+1
-0
lib/engines/mistralrs/src/lib.rs
lib/engines/mistralrs/src/lib.rs
+1
-1
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+1
-0
lib/llm/src/engines.rs
lib/llm/src/engines.rs
+3
-2
lib/llm/src/migration.rs
lib/llm/src/migration.rs
+3
-1
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+3
-1
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+3
-1
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+7
-0
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+17
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+5
-1
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+27
-1
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+27
-1
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+77
-4
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+28
-2
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+73
-5
No files found.
lib/engines/llamacpp/src/lib.rs
View file @
f476fd74
...
@@ -268,6 +268,7 @@ fn run_request(
...
@@ -268,6 +268,7 @@ fn run_request(
//text: if output.text.is_empty() { None } else { Some(output.text) },
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs
:
None
,
// TODO output.cumulative_logprob.map(|v| v as f64),
cum_log_probs
:
None
,
// TODO output.cumulative_logprob.map(|v| v as f64),
log_probs
:
None
,
// TODO output.logprobs
log_probs
:
None
,
// TODO output.logprobs
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
index
:
None
,
};
};
...
...
lib/engines/mistralrs/src/lib.rs
View file @
f476fd74
...
@@ -590,7 +590,7 @@ impl
...
@@ -590,7 +590,7 @@ impl
None
=>
None
,
None
=>
None
,
};
};
#[allow(deprecated)]
#[allow(deprecated)]
let
inner
=
response_generator
.create_choice
(
0
,
Some
(
from_assistant
),
None
);
let
inner
=
response_generator
.create_choice
(
0
,
Some
(
from_assistant
),
None
,
None
);
let
ann
=
Annotated
{
let
ann
=
Annotated
{
id
:
None
,
id
:
None
,
data
:
Some
(
inner
),
data
:
Some
(
inner
),
...
...
lib/llm/src/backend.rs
View file @
f476fd74
...
@@ -231,6 +231,7 @@ impl
...
@@ -231,6 +231,7 @@ impl
text
:
data
.text
,
text
:
data
.text
,
cum_log_probs
:
data
.cum_log_probs
,
cum_log_probs
:
data
.cum_log_probs
,
log_probs
:
data
.log_probs
,
log_probs
:
data
.log_probs
,
top_logprobs
:
data
.top_logprobs
,
finish_reason
:
data
.finish_reason
,
finish_reason
:
data
.finish_reason
,
//mdcsum: mdcsum.clone(),
//mdcsum: mdcsum.clone(),
index
:
data
.index
,
index
:
data
.index
,
...
...
lib/llm/src/engines.rs
View file @
f476fd74
...
@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
...
@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
text
:
None
,
text
:
None
,
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
index
:
None
,
};
};
...
@@ -242,11 +243,11 @@ impl
...
@@ -242,11 +243,11 @@ impl
let
mut
id
=
1
;
let
mut
id
=
1
;
for
c
in
chars_string
.chars
()
{
for
c
in
chars_string
.chars
()
{
tokio
::
time
::
sleep
(
*
TOKEN_ECHO_DELAY
)
.await
;
tokio
::
time
::
sleep
(
*
TOKEN_ECHO_DELAY
)
.await
;
let
response
=
deltas
.create_choice
(
0
,
Some
(
c
.to_string
()),
None
);
let
response
=
deltas
.create_choice
(
0
,
Some
(
c
.to_string
()),
None
,
None
);
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
id
+=
1
;
id
+=
1
;
}
}
let
response
=
deltas
.create_choice
(
0
,
None
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
));
let
response
=
deltas
.create_choice
(
0
,
None
,
Some
(
async_openai
::
types
::
CompletionFinishReason
::
Stop
)
,
None
);
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
};
};
...
...
lib/llm/src/migration.rs
View file @
f476fd74
...
@@ -166,7 +166,7 @@ impl RetryManager {
...
@@ -166,7 +166,7 @@ impl RetryManager {
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
use
crate
::
protocols
::
common
::{
SamplingOptions
,
StopConditions
};
use
crate
::
protocols
::
common
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
dynamo_runtime
::
pipeline
::
context
::
Controller
;
use
dynamo_runtime
::
pipeline
::
context
::
Controller
;
use
dynamo_runtime
::
pipeline
::
AsyncEngine
;
use
dynamo_runtime
::
pipeline
::
AsyncEngine
;
use
std
::
sync
::
atomic
::{
AtomicU32
,
Ordering
};
use
std
::
sync
::
atomic
::{
AtomicU32
,
Ordering
};
...
@@ -183,6 +183,7 @@ mod tests {
...
@@ -183,6 +183,7 @@ mod tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
sampling_options
:
SamplingOptions
::
default
(),
sampling_options
:
SamplingOptions
::
default
(),
output_options
:
OutputOptions
::
default
(),
eos_token_ids
:
vec!
[],
eos_token_ids
:
vec!
[],
mdc_sum
:
None
,
mdc_sum
:
None
,
annotations
:
vec!
[],
annotations
:
vec!
[],
...
@@ -198,6 +199,7 @@ mod tests {
...
@@ -198,6 +199,7 @@ mod tests {
text
:
Some
(
format!
(
"token_{}"
,
token_id
)),
text
:
Some
(
format!
(
"token_{}"
,
token_id
)),
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
index
:
None
,
})
})
...
...
lib/llm/src/mocker/engine.rs
View file @
f476fd74
...
@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
text
:
None
,
text
:
None
,
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
index
:
None
,
};
};
...
@@ -525,7 +526,7 @@ mod integration_tests {
...
@@ -525,7 +526,7 @@ mod integration_tests {
use
super
::
*
;
use
super
::
*
;
use
crate
::
kv_router
::
indexer
::
RouterEvent
;
use
crate
::
kv_router
::
indexer
::
RouterEvent
;
use
crate
::
kv_router
::
KV_EVENT_SUBJECT
;
use
crate
::
kv_router
::
KV_EVENT_SUBJECT
;
use
crate
::
protocols
::
common
::{
SamplingOptions
,
StopConditions
};
use
crate
::
protocols
::
common
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
dynamo_runtime
::{
use
dynamo_runtime
::{
pipeline
::
Context
,
pipeline
::
Context
,
pipeline
::{
network
::
Ingress
,
PushRouter
},
pipeline
::{
network
::
Ingress
,
PushRouter
},
...
@@ -641,6 +642,7 @@ mod integration_tests {
...
@@ -641,6 +642,7 @@ mod integration_tests {
..
Default
::
default
()
..
Default
::
default
()
},
},
sampling_options
:
SamplingOptions
::
default
(),
sampling_options
:
SamplingOptions
::
default
(),
output_options
:
OutputOptions
::
default
(),
eos_token_ids
:
vec!
[],
eos_token_ids
:
vec!
[],
mdc_sum
:
None
,
mdc_sum
:
None
,
annotations
:
vec!
[
format!
(
"dp_rank:{dp_rank}"
)],
annotations
:
vec!
[
format!
(
"dp_rank:{dp_rank}"
)],
...
...
lib/llm/src/preprocessor.rs
View file @
f476fd74
...
@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{
...
@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{
use
dynamo_runtime
::
protocols
::
annotated
::{
Annotated
,
AnnotationsProvider
};
use
dynamo_runtime
::
protocols
::
annotated
::{
Annotated
,
AnnotationsProvider
};
use
crate
::
protocols
::{
use
crate
::
protocols
::{
common
::{
SamplingOptionsProvider
,
StopConditionsProvider
},
common
::{
OutputOptionsProvider
,
SamplingOptionsProvider
,
StopConditionsProvider
},
openai
::{
openai
::{
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
chat_completions
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
},
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
completions
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
},
...
@@ -146,6 +146,7 @@ impl OpenAIPreprocessor {
...
@@ -146,6 +146,7 @@ impl OpenAIPreprocessor {
+
AnnotationsProvider
+
AnnotationsProvider
+
SamplingOptionsProvider
+
SamplingOptionsProvider
+
StopConditionsProvider
+
StopConditionsProvider
+
OutputOptionsProvider
+
NvExtProvider
,
+
NvExtProvider
,
>
(
>
(
&
self
,
&
self
,
...
@@ -249,6 +250,7 @@ impl OpenAIPreprocessor {
...
@@ -249,6 +250,7 @@ impl OpenAIPreprocessor {
builder
.stop_conditions
(
stop_conditions
);
builder
.stop_conditions
(
stop_conditions
);
builder
.sampling_options
(
request
.extract_sampling_options
()
?
);
builder
.sampling_options
(
request
.extract_sampling_options
()
?
);
builder
.output_options
(
request
.extract_output_options
()
?
);
builder
.annotations
(
request
.annotations
()
.unwrap_or_default
());
builder
.annotations
(
request
.annotations
()
.unwrap_or_default
());
builder
.mdc_sum
(
Some
(
self
.mdcsum
.clone
()));
builder
.mdc_sum
(
Some
(
self
.mdcsum
.clone
()));
builder
.estimated_prefix_hit_num_blocks
(
None
);
builder
.estimated_prefix_hit_num_blocks
(
None
);
...
...
lib/llm/src/protocols/common.rs
View file @
f476fd74
...
@@ -45,6 +45,10 @@ pub trait StopConditionsProvider {
...
@@ -45,6 +45,10 @@ pub trait StopConditionsProvider {
fn
extract_stop_conditions
(
&
self
)
->
Result
<
StopConditions
>
;
fn
extract_stop_conditions
(
&
self
)
->
Result
<
StopConditions
>
;
}
}
/// Implemented by request types that can produce the engine-facing
/// [`OutputOptions`] for a request.
pub trait OutputOptionsProvider {
    /// Builds the [`OutputOptions`] for this request.
    ///
    /// # Errors
    /// Returns an error when the options cannot be derived from the request.
    fn extract_output_options(&self) -> Result<OutputOptions>;
}
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq,
Eq)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq,
Eq)]
pub
enum
FinishReason
{
pub
enum
FinishReason
{
#[serde(rename
=
"eos"
)]
#[serde(rename
=
"eos"
)]
...
@@ -179,6 +183,9 @@ pub struct CompletionRequest {
...
@@ -179,6 +183,9 @@ pub struct CompletionRequest {
/// are needed.
/// are needed.
pub
sampling_options
:
SamplingOptions
,
pub
sampling_options
:
SamplingOptions
,
#[builder(default)]
pub
output_options
:
OutputOptions
,
/// The computed checksum of the Model Deployment Card (MDC).
/// The computed checksum of the Model Deployment Card (MDC).
#[builder(default)]
#[builder(default)]
pub
mdc_sum
:
Option
<
String
>
,
pub
mdc_sum
:
Option
<
String
>
,
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
f476fd74
...
@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
...
@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
/// One candidate token together with its log probability.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
    /// Rank of this candidate.
    // NOTE(review): presumably the position among the reported candidates;
    // whether it is 0- or 1-based is not visible here — confirm with the backend.
    pub rank: u32,
    /// Token id of the candidate.
    pub token_id: TokenIdType,
    /// Text of the candidate token, if the backend supplied it.
    pub token: TokenType,
    /// Log probability of the candidate.
    pub logprob: f64,
}

/// Candidate log probabilities laid out as `num_tokens x top_logprobs`:
/// one inner list of candidates per token.
pub type TopLogprobs = Vec<Vec<TopLogprob>>;
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
BackendOutput
{
pub
struct
BackendOutput
{
/// New token_ids generated from the LLM Engine
/// New token_ids generated from the LLM Engine
...
@@ -41,6 +50,8 @@ pub struct BackendOutput {
...
@@ -41,6 +50,8 @@ pub struct BackendOutput {
/// Optional log probabilities
/// Optional log probabilities
pub
log_probs
:
Option
<
LogProbs
>
,
pub
log_probs
:
Option
<
LogProbs
>
,
pub
top_logprobs
:
Option
<
TopLogprobs
>
,
// TODO: Enrich this with more information as can apply our first-level postprocessing
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
// logic and return more detailed information
pub
finish_reason
:
Option
<
FinishReason
>
,
pub
finish_reason
:
Option
<
FinishReason
>
,
...
@@ -77,6 +88,8 @@ pub struct LLMEngineOutput {
...
@@ -77,6 +88,8 @@ pub struct LLMEngineOutput {
/// Optional log probabilities
/// Optional log probabilities
pub
log_probs
:
Option
<
LogProbs
>
,
pub
log_probs
:
Option
<
LogProbs
>
,
pub
top_logprobs
:
Option
<
TopLogprobs
>
,
// TODO: Enrich this with more information as can apply our first-level postprocessing
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
// logic and return more detailed information
pub
finish_reason
:
Option
<
FinishReason
>
,
pub
finish_reason
:
Option
<
FinishReason
>
,
...
@@ -93,6 +106,7 @@ impl LLMEngineOutput {
...
@@ -93,6 +106,7 @@ impl LLMEngineOutput {
text
:
None
,
text
:
None
,
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Cancelled
),
finish_reason
:
Some
(
FinishReason
::
Cancelled
),
index
:
None
,
index
:
None
,
}
}
...
@@ -106,6 +120,7 @@ impl LLMEngineOutput {
...
@@ -106,6 +120,7 @@ impl LLMEngineOutput {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Stop
),
finish_reason
:
Some
(
FinishReason
::
Stop
),
top_logprobs
:
None
,
index
:
None
,
index
:
None
,
}
}
}
}
...
@@ -117,6 +132,7 @@ impl LLMEngineOutput {
...
@@ -117,6 +132,7 @@ impl LLMEngineOutput {
text
:
None
,
text
:
None
,
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Length
),
finish_reason
:
Some
(
FinishReason
::
Length
),
index
:
None
,
index
:
None
,
}
}
...
@@ -129,6 +145,7 @@ impl LLMEngineOutput {
...
@@ -129,6 +145,7 @@ impl LLMEngineOutput {
text
:
None
,
text
:
None
,
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
top_logprobs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Error
(
err_msg
)),
finish_reason
:
Some
(
FinishReason
::
Error
(
err_msg
)),
index
:
None
,
index
:
None
,
}
}
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
f476fd74
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
use
derive_builder
::
Builder
;
use
derive_builder
::
Builder
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
super
::{
SamplingOptions
,
StopConditions
};
use
super
::{
OutputOptions
,
SamplingOptions
,
StopConditions
};
use
crate
::
protocols
::
TokenIdType
;
use
crate
::
protocols
::
TokenIdType
;
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
...
@@ -29,6 +29,10 @@ pub struct PreprocessedRequest {
...
@@ -29,6 +29,10 @@ pub struct PreprocessedRequest {
/// are needed.
/// are needed.
pub
sampling_options
:
SamplingOptions
,
pub
sampling_options
:
SamplingOptions
,
/// OutputOptions are options that control the output of the inference engine such as whether
/// to return log probabilities, or whether to skip special tokens in output.
pub
output_options
:
OutputOptions
,
/// The EOS token ID(s) for the Model
/// The EOS token ID(s) for the Model
/// Not every backend needs this, but those that do can find it here.
/// Not every backend needs this, but those that do can find it here.
/// TODO - refactor this to a better location
/// TODO - refactor this to a better location
...
...
lib/llm/src/protocols/openai.rs
View file @
f476fd74
...
@@ -17,7 +17,7 @@ use anyhow::Result;
...
@@ -17,7 +17,7 @@ use anyhow::Result;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
super
::{
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
common
::{
self
,
OutputOptionsProvider
,
SamplingOptionsProvider
,
StopConditionsProvider
},
ContentProvider
,
ContentProvider
,
};
};
use
crate
::
protocols
::
openai
::
common_ext
::
CommonExtProvider
;
use
crate
::
protocols
::
openai
::
common_ext
::
CommonExtProvider
;
...
@@ -79,6 +79,16 @@ trait OpenAIStopConditionsProvider {
...
@@ -79,6 +79,16 @@ trait OpenAIStopConditionsProvider {
}
}
}
}
/// Accessors an OpenAI-style request exposes for the values that are copied
/// into the common output options.
///
/// Every accessor returns `None` when the request does not specify the value.
trait OpenAIOutputOptionsProvider {
    /// Value for the `logprobs` output option.
    fn get_logprobs(&self) -> Option<u32>;

    /// Value for the `prompt_logprobs` output option.
    fn get_prompt_logprobs(&self) -> Option<u32>;

    /// Value for the `skip_special_tokens` output option.
    fn get_skip_special_tokens(&self) -> Option<bool>;

    /// Value for the `formatted_prompt` output option.
    fn get_formatted_prompt(&self) -> Option<bool>;
}
impl
<
T
:
OpenAISamplingOptionsProvider
+
CommonExtProvider
>
SamplingOptionsProvider
for
T
{
impl
<
T
:
OpenAISamplingOptionsProvider
+
CommonExtProvider
>
SamplingOptionsProvider
for
T
{
fn
extract_sampling_options
(
&
self
)
->
Result
<
common
::
SamplingOptions
>
{
fn
extract_sampling_options
(
&
self
)
->
Result
<
common
::
SamplingOptions
>
{
// let result = self.validate();
// let result = self.validate();
...
@@ -168,6 +178,22 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
...
@@ -168,6 +178,22 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
}
}
}
}
/// Any type exposing the OpenAI output accessors automatically provides the
/// common [`common::OutputOptions`].
impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
    /// Copies each accessor's value into the field of the same name.
    ///
    /// This implementation never returns `Err`; the `Result` exists to satisfy
    /// the trait signature.
    fn extract_output_options(&self) -> Result<common::OutputOptions> {
        Ok(common::OutputOptions {
            logprobs: self.get_logprobs(),
            prompt_logprobs: self.get_prompt_logprobs(),
            skip_special_tokens: self.get_skip_special_tokens(),
            formatted_prompt: self.get_formatted_prompt(),
        })
    }
}
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
Send
+
Sync
+
'static
Send
+
Sync
+
'static
{
{
...
...
lib/llm/src/protocols/openai/chat_completions.rs
View file @
f476fd74
...
@@ -23,7 +23,8 @@ use super::{
...
@@ -23,7 +23,8 @@ use super::{
common_ext
::{
CommonExt
,
CommonExtProvider
},
common_ext
::{
CommonExt
,
CommonExtProvider
},
nvext
::
NvExt
,
nvext
::
NvExt
,
nvext
::
NvExtProvider
,
nvext
::
NvExtProvider
,
validate
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
validate
,
OpenAIOutputOptionsProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
};
};
mod
aggregator
;
mod
aggregator
;
...
@@ -232,6 +233,31 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
...
@@ -232,6 +233,31 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
}
}
}
}
/// Output-option accessors for chat completion requests.
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
    /// Returns the number of logprobs to request.
    ///
    /// Yields `None` unless the request sets `logprobs` to `true`. When it
    /// does, the result is the request's `top_logprobs` value, or `1` when
    /// `top_logprobs` is absent.
    fn get_logprobs(&self) -> Option<u32> {
        // An absent flag and an explicit `false` are treated the same way.
        if !self.inner.logprobs.unwrap_or(false) {
            return None;
        }
        Some(self.inner.top_logprobs.map_or(1_u32, |count| count as u32))
    }

    /// Always `None` for chat completion requests.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        None
    }

    /// Always `None` for chat completion requests.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None` for chat completion requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// allowing us to validate the data.
/// allowing us to validate the data.
impl
ValidateRequest
for
NvCreateChatCompletionRequest
{
impl
ValidateRequest
for
NvCreateChatCompletionRequest
{
...
...
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
f476fd74
...
@@ -14,7 +14,10 @@
...
@@ -14,7 +14,10 @@
// limitations under the License.
// limitations under the License.
use
super
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
};
use
super
::{
NvCreateChatCompletionRequest
,
NvCreateChatCompletionStreamResponse
};
use
crate
::
protocols
::
common
;
use
crate
::{
protocols
::
common
::{
self
},
types
::
TokenIdType
,
};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl
NvCreateChatCompletionRequest
{
impl
NvCreateChatCompletionRequest
{
...
@@ -25,7 +28,8 @@ impl NvCreateChatCompletionRequest {
...
@@ -25,7 +28,8 @@ impl NvCreateChatCompletionRequest {
pub
fn
response_generator
(
&
self
)
->
DeltaGenerator
{
pub
fn
response_generator
(
&
self
)
->
DeltaGenerator
{
let
options
=
DeltaGeneratorOptions
{
let
options
=
DeltaGeneratorOptions
{
enable_usage
:
true
,
enable_usage
:
true
,
enable_logprobs
:
self
.inner.logprobs
.unwrap_or
(
false
),
enable_logprobs
:
self
.inner.logprobs
.unwrap_or
(
false
)
||
self
.inner.top_logprobs
.unwrap_or
(
0
)
>
0
,
};
};
DeltaGenerator
::
new
(
self
.inner.model
.clone
(),
options
)
DeltaGenerator
::
new
(
self
.inner.model
.clone
(),
options
)
...
@@ -112,6 +116,71 @@ impl DeltaGenerator {
...
@@ -112,6 +116,71 @@ impl DeltaGenerator {
self
.usage.prompt_tokens
=
isl
;
self
.usage.prompt_tokens
=
isl
;
}
}
/// Builds the OpenAI chat `logprobs` payload for one delta.
///
/// # Arguments
/// * `tokens` - Text of each token in the delta; a missing text becomes `""`.
/// * `token_ids` - Id of each token, paired positionally with `tokens`.
/// * `logprobs` - Log probability of each selected token.
/// * `top_logprobs` - Per-token candidate lists.
///
/// # Returns
/// `None` when logprobs are disabled in the generator options or `logprobs`
/// is `None`. Otherwise a [`async_openai::types::ChatChoiceLogprobs`] whose
/// `content` is `None` when `top_logprobs` is `None`, and which has one entry
/// per position otherwise. The inputs are zipped, so the entry count is the
/// length of the shortest input.
// NOTE(review): when the backend reports `logprobs` but no `top_logprobs`, the
// selected-token logprobs are not surfaced (`content` is `None`). Kept as-is;
// confirm this is intended.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::ChatChoiceLogprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    let logprobs = logprobs?;

    let content = top_logprobs.map(|top_logprobs| {
        tokens
            .into_iter()
            .zip(token_ids)
            .zip(logprobs)
            .zip(top_logprobs)
            .map(|(((token, token_id), logprob), candidates)| {
                let token = token.unwrap_or_default();
                // The OpenAI response types carry logprobs as f32.
                let logprob = logprob as f32;

                let selected_is_listed = candidates
                    .iter()
                    .any(|candidate| candidate.token_id == token_id);

                let mut converted_top_lps = candidates
                    .into_iter()
                    .map(|candidate| async_openai::types::TopLogprobs {
                        token: candidate.token.unwrap_or_default(),
                        logprob: candidate.logprob as f32,
                        bytes: None,
                    })
                    .collect::<Vec<async_openai::types::TopLogprobs>>();

                if !selected_is_listed {
                    // If the selected token is not in the top logprobs, add it
                    converted_top_lps.push(async_openai::types::TopLogprobs {
                        token: token.clone(),
                        logprob,
                        bytes: None,
                    });
                }

                async_openai::types::ChatCompletionTokenLogprob {
                    token,
                    logprob,
                    bytes: None,
                    top_logprobs: converted_top_lps,
                }
            })
            .collect()
    });

    Some(async_openai::types::ChatChoiceLogprobs {
        content,
        refusal: None,
    })
}
/// Creates a choice within a chat completion response.
/// Creates a choice within a chat completion response.
///
///
/// # Arguments
/// # Arguments
...
@@ -203,8 +272,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
...
@@ -203,8 +272,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
self
.usage.completion_tokens
+=
token_length
;
self
.usage.completion_tokens
+=
token_length
;
}
}
// TODO: Implement log probabilities aggregation.
let
logprobs
=
self
.create_logprobs
(
let
logprobs
=
None
;
delta
.tokens
,
delta
.token_ids
,
delta
.log_probs
,
delta
.top_logprobs
,
);
// Map backend finish reasons to OpenAI's finish reasons.
// Map backend finish reasons to OpenAI's finish reasons.
let
finish_reason
=
match
delta
.finish_reason
{
let
finish_reason
=
match
delta
.finish_reason
{
...
...
lib/llm/src/protocols/openai/completions.rs
View file @
f476fd74
...
@@ -21,10 +21,11 @@ use validator::Validate;
...
@@ -21,10 +21,11 @@ use validator::Validate;
use
crate
::
engines
::
ValidateRequest
;
use
crate
::
engines
::
ValidateRequest
;
use
super
::{
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
common
::{
self
,
OutputOptionsProvider
,
SamplingOptionsProvider
,
StopConditionsProvider
},
common_ext
::{
CommonExt
,
CommonExtProvider
},
common_ext
::{
CommonExt
,
CommonExtProvider
},
nvext
::{
NvExt
,
NvExtProvider
},
nvext
::{
NvExt
,
NvExtProvider
},
validate
,
ContentProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
validate
,
ContentProvider
,
OpenAIOutputOptionsProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
};
};
mod
aggregator
;
mod
aggregator
;
...
@@ -279,6 +280,10 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
...
@@ -279,6 +280,10 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
.extract_sampling_options
()
.extract_sampling_options
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to extract sampling options: {}"
,
e
))
?
;
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to extract sampling options: {}"
,
e
))
?
;
let
output_options
=
request
.extract_output_options
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to extract output options: {}"
,
e
))
?
;
let
prompt
=
common
::
PromptType
::
Completion
(
common
::
CompletionContext
{
let
prompt
=
common
::
PromptType
::
Completion
(
common
::
CompletionContext
{
prompt
:
prompt_to_string
(
&
request
.inner.prompt
),
prompt
:
prompt_to_string
(
&
request
.inner.prompt
),
system_prompt
:
None
,
system_prompt
:
None
,
...
@@ -288,6 +293,7 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
...
@@ -288,6 +293,7 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
prompt
,
prompt
,
stop_conditions
,
stop_conditions
,
sampling_options
,
sampling_options
,
output_options
,
mdc_sum
:
None
,
mdc_sum
:
None
,
annotations
:
None
,
annotations
:
None
,
})
})
...
@@ -329,6 +335,26 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
...
@@ -329,6 +335,26 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
}
}
}
}
/// Output-option accessors for (non-chat) completion requests.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Returns the request's `logprobs` value widened to `u32`, or `None`
    /// when the request does not set it.
    fn get_logprobs(&self) -> Option<u32> {
        let requested = self.inner.logprobs?;
        Some(requested as u32)
    }

    /// Returns `Some(1)` when the request sets `echo` to `true`, `None`
    /// otherwise.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            Some(false) | None => None,
        }
    }

    /// Always `None` for completion requests.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None` for completion requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data.
/// allowing us to validate the data.
impl
ValidateRequest
for
NvCreateCompletionRequest
{
impl
ValidateRequest
for
NvCreateCompletionRequest
{
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
f476fd74
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
// limitations under the License.
// limitations under the License.
use
super
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
};
use
super
::{
NvCreateCompletionRequest
,
NvCreateCompletionResponse
};
use
crate
::
protocols
::
common
;
use
crate
::
{
protocols
::
common
,
types
::
TokenIdType
}
;
impl
NvCreateCompletionRequest
{
impl
NvCreateCompletionRequest
{
// put this method on the request
// put this method on the request
...
@@ -22,7 +22,7 @@ impl NvCreateCompletionRequest {
...
@@ -22,7 +22,7 @@ impl NvCreateCompletionRequest {
pub
fn
response_generator
(
&
self
)
->
DeltaGenerator
{
pub
fn
response_generator
(
&
self
)
->
DeltaGenerator
{
let
options
=
DeltaGeneratorOptions
{
let
options
=
DeltaGeneratorOptions
{
enable_usage
:
true
,
enable_usage
:
true
,
enable_logprobs
:
false
,
enable_logprobs
:
self
.inner.logprobs
.unwrap_or
(
0
)
>
0
,
};
};
DeltaGenerator
::
new
(
self
.inner.model
.clone
(),
options
)
DeltaGenerator
::
new
(
self
.inner.model
.clone
(),
options
)
...
@@ -82,11 +82,74 @@ impl DeltaGenerator {
...
@@ -82,11 +82,74 @@ impl DeltaGenerator {
self
.usage.prompt_tokens
=
isl
;
self
.usage.prompt_tokens
=
isl
;
}
}
/// Builds the OpenAI completions `logprobs` payload for one delta.
///
/// # Arguments
/// * `tokens` - Text of each token in the delta; a missing text becomes `""`.
/// * `token_ids` - Id of each token, paired positionally with `tokens`.
/// * `logprobs` - Log probability of each selected token.
/// * `top_logprobs` - Per-token candidate lists.
///
/// # Returns
/// `None` when logprobs are disabled in the generator options or `logprobs`
/// is `None`. Otherwise a [`async_openai::types::Logprobs`] where:
/// * `tokens` has one entry per `(token, token_id)` pair,
/// * `token_logprobs` has one entry per pair that also has a logprob,
/// * `top_logprobs` has one JSON array per position that also has a candidate
///   list, and is empty when `top_logprobs` is `None`,
/// * `text_offset` is always empty.
///
/// # Panics
/// Panics if a candidate list cannot be serialized to JSON.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::Logprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    let logprobs = logprobs?;

    let toks = tokens
        .into_iter()
        .zip(token_ids)
        .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
        .collect::<Vec<(String, TokenIdType)>>();

    // Never keep more logprobs than there are tokens; the OpenAI response
    // types carry logprobs as f32.
    let tok_lps = logprobs
        .into_iter()
        .take(toks.len())
        .map(|lp| lp as f32)
        .collect::<Vec<f32>>();

    // A missing candidate table behaves like an empty one: no entries.
    let top_logprobs = top_logprobs.unwrap_or_default();
    let top_lps = toks
        .iter()
        .zip(tok_lps.iter())
        .zip(top_logprobs.iter())
        .map(|(((t, tid), lp), candidates)| {
            let mut converted_top_lps = candidates
                .iter()
                .map(|candidate| async_openai::types::TopLogprobs {
                    token: candidate.token.clone().unwrap_or_default(),
                    logprob: candidate.logprob as f32,
                    bytes: None,
                })
                .collect::<Vec<async_openai::types::TopLogprobs>>();

            let selected_is_listed = candidates
                .iter()
                .any(|candidate| candidate.token_id == *tid);
            if !selected_is_listed {
                // If the selected token is not in the top logprobs, add it
                converted_top_lps.push(async_openai::types::TopLogprobs {
                    token: t.clone(),
                    logprob: *lp,
                    bytes: None,
                });
            }

            serde_json::to_value(converted_top_lps)
                .expect("a list of token/logprob entries is always representable as JSON")
        })
        .collect();

    Some(async_openai::types::Logprobs {
        tokens: toks.into_iter().map(|(t, _)| t).collect(),
        token_logprobs: tok_lps.into_iter().map(Some).collect(),
        text_offset: vec![],
        top_logprobs: top_lps,
    })
}
pub
fn
create_choice
(
pub
fn
create_choice
(
&
self
,
&
self
,
index
:
u32
,
index
:
u32
,
text
:
Option
<
String
>
,
text
:
Option
<
String
>
,
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
,
finish_reason
:
Option
<
async_openai
::
types
::
CompletionFinishReason
>
,
logprobs
:
Option
<
async_openai
::
types
::
Logprobs
>
,
)
->
NvCreateCompletionResponse
{
)
->
NvCreateCompletionResponse
{
// todo - update for tool calling
// todo - update for tool calling
...
@@ -105,7 +168,7 @@ impl DeltaGenerator {
...
@@ -105,7 +168,7 @@ impl DeltaGenerator {
text
:
text
.unwrap_or_default
(),
text
:
text
.unwrap_or_default
(),
index
,
index
,
finish_reason
,
finish_reason
,
logprobs
:
None
,
logprobs
,
}],
}],
usage
:
if
self
.options.enable_usage
{
usage
:
if
self
.options.enable_usage
{
Some
(
usage
)
Some
(
usage
)
...
@@ -136,13 +199,18 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
...
@@ -136,13 +199,18 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
self
.usage.completion_tokens
+=
token_length
;
self
.usage.completion_tokens
+=
token_length
;
}
}
// TODO logprobs
let
logprobs
=
self
.create_logprobs
(
delta
.tokens
,
delta
.token_ids
,
delta
.log_probs
,
delta
.top_logprobs
,
);
let
finish_reason
=
delta
.finish_reason
.map
(
Into
::
into
);
let
finish_reason
=
delta
.finish_reason
.map
(
Into
::
into
);
// create choice
// create choice
let
index
=
delta
.index
.unwrap_or
(
0
);
let
index
=
delta
.index
.unwrap_or
(
0
);
let
response
=
self
.create_choice
(
index
,
delta
.text
.clone
(),
finish_reason
);
let
response
=
self
.create_choice
(
index
,
delta
.text
.clone
(),
finish_reason
,
logprobs
);
Ok
(
response
)
Ok
(
response
)
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment