Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9d7c5df5
Unverified
Commit
9d7c5df5
authored
Jun 26, 2025
by
Paul Hendricks
Committed by
GitHub
Jun 26, 2025
Browse files
refactor: remove dead protocols code and organize imports idiomatically (#1669)
parent
03d976c7
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
34 additions
and
696 deletions
+34
-696
lib/llm/src/protocols.rs
lib/llm/src/protocols.rs
+2
-8
lib/llm/src/protocols/codec.rs
lib/llm/src/protocols/codec.rs
+2
-1
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+7
-551
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+2
-3
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+7
-120
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+4
-3
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+4
-3
lib/llm/src/protocols/openai/embeddings.rs
lib/llm/src/protocols/openai/embeddings.rs
+2
-4
lib/llm/src/protocols/openai/embeddings/aggregator.rs
lib/llm/src/protocols/openai/embeddings/aggregator.rs
+4
-3
No files found.
lib/llm/src/protocols.rs
View file @
9d7c5df5
...
...
@@ -19,9 +19,10 @@
//! both publicly via the HTTP API and internally between Dynamo components.
//!
use
std
::
pin
::
Pin
;
use
futures
::{
Stream
,
StreamExt
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
pin
::
Pin
;
pub
mod
codec
;
pub
mod
common
;
...
...
@@ -48,13 +49,6 @@ pub trait ContentProvider {
fn
content
(
&
self
)
->
String
;
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
struct
Usage
{
pub
prompt_tokens
:
i32
,
pub
completion_tokens
:
i32
,
pub
total_tokens
:
i32
,
}
/// Converts of a stream of [codec::Message]s into a stream of [Annotated]s.
pub
fn
convert_sse_stream
<
R
>
(
stream
:
DataStream
<
Result
<
codec
::
Message
,
codec
::
SseCodecError
>>
,
...
...
lib/llm/src/protocols/codec.rs
View file @
9d7c5df5
...
...
@@ -23,10 +23,11 @@
// TODO: Determine if we should use an External EventSource crate. There appear to be several
// potential candidates.
use
std
::{
io
::
Cursor
,
pin
::
Pin
};
use
bytes
::
BytesMut
;
use
futures
::
Stream
;
use
serde
::
Deserialize
;
use
std
::{
io
::
Cursor
,
pin
::
Pin
};
use
tokio_util
::
codec
::{
Decoder
,
FramedRead
,
LinesCodec
};
use
super
::
Annotated
;
...
...
lib/llm/src/protocols/common.rs
View file @
9d7c5df5
...
...
@@ -24,13 +24,10 @@
//! need some additional information to propagate intermediate results for improved observability.
//! The metadata is transferred via the other arms of the `StreamingResponse` enum.
//!
use
std
::
collections
::
HashMap
;
use
std
::
time
::
SystemTime
;
use
anyhow
::
Result
;
use
derive_builder
::
Builder
;
use
serde
::
ser
::
SerializeStruct
;
use
serde
::{
Deserialize
,
Deserializer
,
Serialize
,
Serializer
};
use
serde
::{
Deserialize
,
Serialize
};
use
super
::
TokenIdType
;
...
...
@@ -416,68 +413,6 @@ pub struct TopLogprob {
pub
bytes
:
Option
<
Vec
<
u8
>>
,
}
// /// UserData is a struct that contains user-defined data that can be passed to the inference engine.
// /// This information will be use to annotate the distributed traces for improved observability.
// #[derive(Serialize, Deserialize, Debug, Clone, Default)]
// pub struct UserData {
// /// Apply server-side prompt template to the request
// pub request_uuid: Option<uuid::Uuid>,
// }
/// StreamingResponse is the primary response object for the LLM Engine. The response stream
/// can emit three different types of messages. The Initialize and Finalize messages are optional
/// and primarily used over disaggreated transports to move states from the server to the client.
#[derive(Serialize,
Deserialize,
Debug)]
pub
enum
StreamingResponse
{
/// Initialize transports a Prologue object which communication the LLM Engine Context
Initialize
(
Option
<
Prologue
>
),
/// Step is the primary data in the response stream. It contains the StreamingCompletionResponse
Step
(
Box
<
StreamingCompletionResponse
>
),
/// Finalize is an optional final message in the response stream. It contains the Epilogue object which
/// is used to communicate extra information about the completion and the engine statistics.
Finalize
(
Option
<
Epilogue
>
),
}
// TODO(ryan) - this should be part of the internal api as it is not deserializble
// the public API should drop the Option<Arc<Stats>> in favor of Option<Stats>
// the two variants both serialize to the same json; however, the internal version
// can not be deserialized directly.
// we use the internal one on the server side to avoid the cost of cloning the Stats
// object; however, client side, we should always fully materialize the Stats object.
//
// TODO(ryan) - update this object to use an enum where we have the current definition be the
// StepResponse arm; then we will add the following arms:
// - Initialize(Prologue)
// - Step()
// - Finalize(Epilogue)
/// This is the first message that will be emitted by an Engine Response Stream
/// It indicates that the request has been preprocessed and queued for execution on the backend.
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Prologue
{
/// If the request was preprocessed with a prompt template, this will contain the formatted prompt
pub
formatted_prompt
:
Option
<
String
>
,
/// If the request did not contain TokenIds, this will contain the token_ids that were generated
/// from tokenizing the prompt.
pub
input_token_ids
:
Option
<
Vec
<
TokenIdType
>>
,
}
/// This is the final message that will be emitted by a Engine Response Stream when it
/// finishes without error. In some cases, the engine may emit an error which will indicate
/// the end of the steam. Another case in which an Finalize(Epilogue) will not be emitted is
/// if the response handler has stalled and too many responses
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Epilogue
{}
#[derive(Debug)]
pub
struct
StreamingCompletionResponse
{
pub
delta
:
Delta
,
pub
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
StreamState
{
Active
,
...
...
@@ -506,6 +441,12 @@ pub struct SequencePositionData {
pub
logprobs
:
Option
<
LogProbs
>
,
}
#[derive(Debug)]
pub
struct
StreamingCompletionResponse
{
pub
delta
:
Delta
,
pub
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
// todo(ryan) - we need to create a DeltaBuilder which is a mutable object that can be passed
// around from the low-level compute engine to the high-level api. The DeltaBuilder will allow
// us to construct the Delta object at multiple layers in the streaming response path.
...
...
@@ -549,134 +490,6 @@ pub struct Usage {
pub
output_tokens_count
:
usize
,
}
// todo(ryan) - we need to update this object to make it more generic
// we need to define a set of generic stats traits that allow those stats to be None
// then back them by a concrete implementation like a TrtllmStats object
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
Stats
{
/// Time since the last Epoch/Forward Pass in microseconds (us).
/// This is measured and recorded by the Response Router rather then the
/// Inference Engine. Note, when evaluating the responses, if the this
/// values is greater then the stream's measured value, then there was a gap
/// between forward passes. In normal operation, the value of this field should
/// be less than the recorded value on the response stream.
pub
time_since_last_forward_pass_us
:
Option
<
u64
>
,
pub
request_active_count
:
u32
,
pub
request_context_count
:
u32
,
pub
request_generation_count
:
u32
,
pub
request_scheduled_count
:
u32
,
pub
request_max_count
:
u32
,
pub
kv_free_cache_blocks
:
u64
,
pub
kv_max_cache_blocks
:
u64
,
pub
kv_used_cache_blocks
:
u64
,
pub
kv_tokens_per_cache_block
:
u64
,
pub
runtime_cpu_memory_usage
:
u64
,
pub
runtime_gpu_memory_usage
:
u64
,
pub
runtime_pinned_memory_usage
:
u64
,
pub
iteration_counter
:
u64
,
pub
microbatch_id
:
u64
,
pub
total_context_tokens
:
u32
,
pub
timestamp
:
String
,
}
impl
Serialize
for
StreamingCompletionResponse
{
fn
serialize
<
S
>
(
&
self
,
serializer
:
S
)
->
Result
<
S
::
Ok
,
S
::
Error
>
where
S
:
Serializer
,
{
let
mut
state
=
serializer
.serialize_struct
(
"StreamingCompletionResponse"
,
2
)
?
;
// Serialize `delta` field
state
.serialize_field
(
"delta"
,
&
self
.delta
)
?
;
state
.end
()
}
}
impl
<
'de
>
Deserialize
<
'de
>
for
StreamingCompletionResponse
{
fn
deserialize
<
D
>
(
deserializer
:
D
)
->
Result
<
Self
,
D
::
Error
>
where
D
:
Deserializer
<
'de
>
,
{
// Create a temporary struct for deserialization
#[derive(Deserialize)]
struct
TempResponse
{
delta
:
Delta
,
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
let
TempResponse
{
delta
,
logprobs
}
=
TempResponse
::
deserialize
(
deserializer
)
?
;
Ok
(
StreamingCompletionResponse
{
delta
,
logprobs
})
}
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
ScatterData
<
T
>
{
pub
x
:
Vec
<
T
>
,
pub
y
:
Vec
<
T
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Trace
{
pub
time_to_first_token
:
u64
,
pub
token_to_token
:
Vec
<
u64
>
,
pub
start
:
SystemTime
,
pub
complete
:
SystemTime
,
pub
initial_tokens
:
u32
,
pub
max_tokens
:
u32
,
pub
t2ft_iteration_count
:
u64
,
pub
t2t_iteration_count
:
Vec
<
u64
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
PerformanceModel
{
// linear regression parameters fitting t2ft vs. initial tokens
pub
t2ft_intercept
:
f64
,
pub
t2ft_slope
:
f64
,
// linear regression parameters fitting t2tl vs. initial tokens
pub
t2tl_intercept
:
f64
,
pub
t2tl_slope
:
f64
,
// r2 values from the regression
pub
t2ft_fit_r2
:
f64
,
pub
t2tl_fit_r2
:
f64
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
CalibrationResults
{
pub
effective_flops
:
f64
,
pub
effective_memory_bandwidth
:
f64
,
pub
max_q
:
u32
,
pub
performance_model
:
PerformanceModel
,
pub
traces
:
Vec
<
Trace
>
,
pub
t2ft_scatter_data
:
ScatterData
<
f64
>
,
pub
t2tl_scatter_data
:
ScatterData
<
f64
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
LoadgenResults
{
pub
stats_by_iteration
:
HashMap
<
u64
,
Stats
>
,
pub
traces
:
Vec
<
Trace
>
,
}
impl
CompletionContext
{
/// Create a new CompletionContext
pub
fn
new
(
prompt
:
String
,
system_prompt
:
Option
<
String
>
)
->
Self
{
...
...
@@ -712,7 +525,6 @@ impl From<CompletionContext> for PromptType {
#[cfg(test)]
mod
tests
{
use
serde_json
;
use
super
::
*
;
...
...
@@ -759,360 +571,4 @@ mod tests {
panic!
(
"Expected a Completion variant"
);
}
}
// #[test]
// fn test_serialize_with_stats() {
// let response = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // Serialize the response
// let serialized = serde_json::to_string(&response).expect("Failed to serialize");
// // Expected JSON string (simplified)
// let expected = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// assert_eq!(
// serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
// serde_json::from_str::<serde_json::Value>(expected).unwrap()
// );
// }
#[test]
fn
test_serialize_without_stats
()
{
let
response
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
false
,
finish_reason
:
None
,
token_ids
:
None
,
tokens
:
None
,
text
:
None
,
sequence_length
:
None
,
index
:
None
,
cum_log_probs
:
None
,
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// Serialize the response
let
serialized
=
serde_json
::
to_string
(
&
response
)
.expect
(
"Failed to serialize"
);
// Expected JSON string
let
expected
=
r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#
;
assert_eq!
(
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
&
serialized
)
.unwrap
(),
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
expected
)
.unwrap
()
);
}
// #[test]
// fn test_deserialize_with_stats() {
// let json_data = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// // Deserialize the JSON string
// let deserialized: StreamingCompletionResponse =
// serde_json::from_str(json_data).expect("Failed to deserialize");
// // Expected response object
// let expected = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // This is wieldy but we can no longer do assert_eq!(deserialized, expected);
// // because the struct no longer has the PartialEq trait
// assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
// assert_eq!(
// deserialized.delta.finish_reason,
// expected.delta.finish_reason
// );
// assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
// assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
// assert_eq!(deserialized.delta.text, expected.delta.text);
// assert_eq!(
// deserialized.delta.sequence_length,
// expected.delta.sequence_length
// );
// assert_eq!(deserialized.delta.index, expected.delta.index);
// assert_eq!(
// deserialized.delta.cum_log_probs,
// expected.delta.cum_log_probs
// );
// assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
// assert_eq!(deserialized.delta.usage, expected.delta.usage);
// assert_eq!(
// deserialized_stats.time_since_last_forward_pass_us,
// expected_stats.time_since_last_forward_pass_us
// );
// assert_eq!(
// deserialized_stats.request_active_count,
// expected_stats.request_active_count
// );
// assert_eq!(
// deserialized_stats.request_context_count,
// expected_stats.request_context_count
// );
// assert_eq!(
// deserialized_stats.request_generation_count,
// expected_stats.request_generation_count
// );
// assert_eq!(
// deserialized_stats.request_scheduled_count,
// expected_stats.request_scheduled_count
// );
// assert_eq!(
// deserialized_stats.request_max_count,
// expected_stats.request_max_count
// );
// assert_eq!(
// deserialized_stats.kv_free_cache_blocks,
// expected_stats.kv_free_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_max_cache_blocks,
// expected_stats.kv_max_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_used_cache_blocks,
// expected_stats.kv_used_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_tokens_per_cache_block,
// expected_stats.kv_tokens_per_cache_block
// );
// assert_eq!(
// deserialized_stats.runtime_cpu_memory_usage,
// expected_stats.runtime_cpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_gpu_memory_usage,
// expected_stats.runtime_gpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_pinned_memory_usage,
// expected_stats.runtime_pinned_memory_usage
// );
// assert_eq!(
// deserialized_stats.iteration_counter,
// expected_stats.iteration_counter
// );
// assert_eq!(
// deserialized_stats.microbatch_id,
// expected_stats.microbatch_id
// );
// assert_eq!(
// deserialized_stats.total_context_tokens,
// expected_stats.total_context_tokens
// );
// assert_eq!(deserialized_stats.timestamp, expected_stats.timestamp);
// }
#[test]
fn
test_deserialize_without_stats
()
{
let
json_data
=
r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#
;
// Deserialize the JSON string
let
deserialized
:
StreamingCompletionResponse
=
serde_json
::
from_str
(
json_data
)
.expect
(
"Failed to deserialize"
);
// Expected response object
let
expected
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
false
,
finish_reason
:
None
,
token_ids
:
None
,
tokens
:
None
,
text
:
None
,
sequence_length
:
None
,
index
:
None
,
cum_log_probs
:
None
,
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// This is wieldy but we can no longer do assert_eq!(deserialized, expected);
// because the struct no longer has the PartialEq trait
assert_eq!
(
deserialized
.delta.is_complete
,
expected
.delta.is_complete
);
assert_eq!
(
deserialized
.delta.finish_reason
,
expected
.delta.finish_reason
);
assert_eq!
(
deserialized
.delta.token_ids
,
expected
.delta.token_ids
);
assert_eq!
(
deserialized
.delta.tokens
,
expected
.delta.tokens
);
assert_eq!
(
deserialized
.delta.text
,
expected
.delta.text
);
assert_eq!
(
deserialized
.delta.sequence_length
,
expected
.delta.sequence_length
);
assert_eq!
(
deserialized
.delta.index
,
expected
.delta.index
);
assert_eq!
(
deserialized
.delta.cum_log_probs
,
expected
.delta.cum_log_probs
);
assert_eq!
(
deserialized
.delta.err_msg
,
expected
.delta.err_msg
);
assert_eq!
(
deserialized
.delta.usage
,
expected
.delta.usage
);
}
#[test]
fn
test_serialize_delta_and_none_stats
()
{
let
response
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
true
,
finish_reason
:
Some
(
FinishReason
::
Length
),
token_ids
:
Some
(
vec!
[
101
,
102
,
103
]),
tokens
:
Some
(
vec!
[
"token1"
.to_string
(),
"token2"
.to_string
()]),
text
:
Some
(
"example text"
.to_string
()),
sequence_length
:
Some
(
3
),
index
:
Some
(
0
),
cum_log_probs
:
Some
(
-
0.5
),
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// Serialize the response
let
serialized
=
serde_json
::
to_string
(
&
response
)
.expect
(
"Failed to serialize"
);
// Expected JSON string where stats is null
let
expected_json
=
r#"{
"delta": {
"is_complete": true,
"finish_reason": "length",
"token_ids": [101, 102, 103],
"tokens": ["token1", "token2"],
"text": "example text",
"sequence_length": 3,
"index": 0,
"cum_log_probs": -0.5,
"err_msg": null,
"usage": null
}
}"#
;
// Parse both the serialized response and the expected JSON as serde_json::Value for easy comparison
assert_eq!
(
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
&
serialized
)
.unwrap
(),
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
expected_json
)
.unwrap
()
);
}
}
lib/llm/src/protocols/common/llm_backend.rs
View file @
9d7c5df5
...
...
@@ -15,14 +15,13 @@
use
serde
::{
Deserialize
,
Serialize
};
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
FinishReason
;
use
crate
::
protocols
::
TokenIdType
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
FinishReason
;
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
BackendOutput
{
/// New token_ids generated from the LLM Engine
...
...
lib/llm/src/protocols/openai.rs
View file @
9d7c5df5
...
...
@@ -13,24 +13,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub
mod
chat_completions
;
pub
mod
completions
;
pub
mod
embeddings
;
pub
mod
models
;
pub
mod
nvext
;
use
std
::
fmt
::
Display
;
use
anyhow
::
Result
;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::{
fmt
::
Display
,
ops
::{
Add
,
Div
,
Mul
,
Sub
},
};
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
ContentProvider
,
};
pub
mod
chat_completions
;
pub
mod
completions
;
pub
mod
embeddings
;
pub
mod
models
;
pub
mod
nvext
;
/// Minimum allowed value for OpenAI's `temperature` sampling option
pub
const
MIN_TEMPERATURE
:
f32
=
0.0
;
...
...
@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
/// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON
/// object corresponding to `R`; however, the comments in the SSE stream `: `
/// may correspond to other types of information, such as performance metrics,
/// as represented by other arms of this enum.
///
/// This is part of the common API as both the client and service need to agree
/// on the format of the streaming responses.
#[derive(Serialize,
Deserialize,
Debug)]
pub
enum
StreamingDelta
<
R
>
{
/// Represents a response delta from the API
Delta
(
R
),
Comment
(
String
),
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
AnnotatedDelta
<
R
>
{
pub
delta
:
R
,
...
...
@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
}
}
/// Common structure for chat completion responses; the only delta is the type of choices which differs
/// between streaming and non-streaming requests.
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
struct
GenericCompletionResponse
<
C
>
// where
// C: Serialize + Clone,
{
/// A unique identifier for the chat completion.
pub
id
:
String
,
/// A list of chat completion choices. Can be more than one if n is greater than 1.
pub
choices
:
Vec
<
C
>
,
/// The Unix timestamp (in seconds) of when the chat completion was created.
pub
created
:
u64
,
/// The model used for the chat completion.
pub
model
:
String
,
/// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub
object
:
String
,
pub
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
/// This fingerprint represents the backend configuration that the model runs with.
///
/// Can be used in conjunction with the seed request parameter to understand when backend changes
/// have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
pub
system_fingerprint
:
Option
<
String
>
,
// TODO() - add NvResponseExtention
}
// todo - move to common location
fn
validate_range
<
T
>
(
value
:
Option
<
T
>
,
range
:
&
(
T
,
T
))
->
Result
<
Option
<
T
>>
where
...
...
@@ -235,30 +180,6 @@ where
Ok
(
Some
(
value
))
}
// todo - move to common location
/// scale value in `src` range to `dst` range
pub
fn
scale_value
<
T
>
(
value
:
&
T
,
src
:
&
(
T
,
T
),
dst
:
&
(
T
,
T
))
->
Result
<
T
>
where
T
:
Copy
+
PartialOrd
+
Add
<
Output
=
T
>
+
Sub
<
Output
=
T
>
+
Mul
<
Output
=
T
>
+
Div
<
Output
=
T
>
+
From
<
f32
>
,
{
let
dst_range
=
dst
.1
-
dst
.0
;
let
src_range
=
src
.1
-
src
.0
;
if
dst_range
==
T
::
from
(
0.0
)
{
anyhow
::
bail!
(
"dst range is 0"
);
}
if
src_range
==
T
::
from
(
0.0
)
{
anyhow
::
bail!
(
"src range is 0"
);
}
let
value_scaled
=
(
*
value
-
src
.0
)
/
src_range
;
Ok
(
dst
.0
+
(
value_scaled
*
dst_range
))
}
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
Send
+
Sync
+
'static
{
...
...
@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
/// Gets the current prompt token count (Input Sequence Length).
fn
get_isl
(
&
self
)
->
Option
<
u32
>
;
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_validate_range
()
{
assert_eq!
(
validate_range
(
Some
(
0.5
),
&
(
0.0
,
1.0
))
.unwrap
(),
Some
(
0.5
));
assert_eq!
(
validate_range
(
Some
(
0.0
),
&
(
0.0
,
1.0
))
.unwrap
(),
Some
(
0.0
));
assert_eq!
(
validate_range
(
Some
(
1.0
),
&
(
1.0
,
1.0
))
.unwrap
(),
Some
(
1.0
));
assert_eq!
(
validate_range
(
Some
(
1_i32
),
&
(
1
,
1
))
.unwrap
(),
Some
(
1
));
assert_eq!
(
validate_range
(
Some
(
1.1
),
&
(
0.0
,
1.0
))
.unwrap_err
()
.to_string
(),
"Value 1.1 is out of range [0, 1]"
);
assert_eq!
(
validate_range
(
Some
(
-
0.1
),
&
(
0.0
,
1.0
))
.unwrap_err
()
.to_string
(),
"Value -0.1 is out of range [0, 1]"
);
}
#[test]
fn
test_scaled_value
()
{
assert_eq!
(
scale_value
(
&
0.5
,
&
(
0.0
,
1.0
),
&
(
0.0
,
2.0
))
.unwrap
(),
1.0
);
assert_eq!
(
scale_value
(
&
0.0
,
&
(
0.0
,
1.0
),
&
(
0.0
,
2.0
))
.unwrap
(),
0.0
);
assert_eq!
(
scale_value
(
&-
1.0
,
&
(
-
2.0
,
2.0
),
&
(
1.0
,
2.0
))
.unwrap
(),
1.25
);
assert
!
(
scale_value
(
&
1.0
,
&
(
1.0
,
1.0
),
&
(
0.0
,
2.0
))
.is_err
());
}
}
lib/llm/src/protocols/openai/chat_completions.rs
View file @
9d7c5df5
...
...
@@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
use
super
::
nvext
::
NvExt
;
use
super
::
nvext
::
NvExtProvider
;
use
super
::
OpenAISamplingOptionsProvider
;
use
super
::
OpenAIStopConditionsProvider
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
mod
aggregator
;
mod
delta
;
...
...
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
View file @
9d7c5df5
...
...
@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::{
collections
::
HashMap
,
pin
::
Pin
};
use
futures
::{
Stream
,
StreamExt
};
use
super
::{
NvCreateChatCompletionResponse
,
NvCreateChatCompletionStreamResponse
};
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::{
collections
::
HashMap
,
pin
::
Pin
};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type
DataStream
<
T
>
=
Pin
<
Box
<
dyn
Stream
<
Item
=
T
>
+
Send
+
Sync
>>
;
...
...
lib/llm/src/protocols/openai/embeddings.rs
View file @
9d7c5df5
...
...
@@ -13,17 +13,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
mod
aggregator
;
mod
nvext
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
// pub use delta::DeltaGenerator;
pub
use
aggregator
::
DeltaAggregator
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
NvCreateEmbeddingRequest
{
...
...
lib/llm/src/protocols/openai/embeddings/aggregator.rs
View file @
9d7c5df5
...
...
@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
pin
::
Pin
;
use
futures
::{
Stream
,
StreamExt
};
use
super
::
NvCreateEmbeddingResponse
;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::
pin
::
Pin
;
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type
DataStream
<
T
>
=
Pin
<
Box
<
dyn
Stream
<
Item
=
T
>
+
Send
+
Sync
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment