Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9d7c5df5
"lib/llm/src/vscode:/vscode.git/clone" did not exist on "66fd6f84ad7d10933aafd9fd1ef768447aa91b36"
Unverified
Commit
9d7c5df5
authored
Jun 26, 2025
by
Paul Hendricks
Committed by
GitHub
Jun 26, 2025
Browse files
refactor: remove dead protocols code and organize imports idiomatically (#1669)
parent
03d976c7
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
34 additions
and
696 deletions
+34
-696
lib/llm/src/protocols.rs
lib/llm/src/protocols.rs
+2
-8
lib/llm/src/protocols/codec.rs
lib/llm/src/protocols/codec.rs
+2
-1
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+7
-551
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+2
-3
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+7
-120
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+4
-3
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+4
-3
lib/llm/src/protocols/openai/embeddings.rs
lib/llm/src/protocols/openai/embeddings.rs
+2
-4
lib/llm/src/protocols/openai/embeddings/aggregator.rs
lib/llm/src/protocols/openai/embeddings/aggregator.rs
+4
-3
No files found.
lib/llm/src/protocols.rs
View file @
9d7c5df5
...
...
@@ -19,9 +19,10 @@
//! both publicly via the HTTP API and internally between Dynamo components.
//!
use
std
::
pin
::
Pin
;
use
futures
::{
Stream
,
StreamExt
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
pin
::
Pin
;
pub
mod
codec
;
pub
mod
common
;
...
...
@@ -48,13 +49,6 @@ pub trait ContentProvider {
fn
content
(
&
self
)
->
String
;
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
struct
Usage
{
pub
prompt_tokens
:
i32
,
pub
completion_tokens
:
i32
,
pub
total_tokens
:
i32
,
}
/// Converts of a stream of [codec::Message]s into a stream of [Annotated]s.
pub
fn
convert_sse_stream
<
R
>
(
stream
:
DataStream
<
Result
<
codec
::
Message
,
codec
::
SseCodecError
>>
,
...
...
lib/llm/src/protocols/codec.rs
View file @
9d7c5df5
...
...
@@ -23,10 +23,11 @@
// TODO: Determine if we should use an External EventSource crate. There appear to be several
// potential candidates.
use
std
::{
io
::
Cursor
,
pin
::
Pin
};
use
bytes
::
BytesMut
;
use
futures
::
Stream
;
use
serde
::
Deserialize
;
use
std
::{
io
::
Cursor
,
pin
::
Pin
};
use
tokio_util
::
codec
::{
Decoder
,
FramedRead
,
LinesCodec
};
use
super
::
Annotated
;
...
...
lib/llm/src/protocols/common.rs
View file @
9d7c5df5
...
...
@@ -24,13 +24,10 @@
//! need some additional information to propagate intermediate results for improved observability.
//! The metadata is transferred via the other arms of the `StreamingResponse` enum.
//!
use
std
::
collections
::
HashMap
;
use
std
::
time
::
SystemTime
;
use
anyhow
::
Result
;
use
derive_builder
::
Builder
;
use
serde
::
ser
::
SerializeStruct
;
use
serde
::{
Deserialize
,
Deserializer
,
Serialize
,
Serializer
};
use
serde
::{
Deserialize
,
Serialize
};
use
super
::
TokenIdType
;
...
...
@@ -416,68 +413,6 @@ pub struct TopLogprob {
pub
bytes
:
Option
<
Vec
<
u8
>>
,
}
// /// UserData is a struct that contains user-defined data that can be passed to the inference engine.
// /// This information will be use to annotate the distributed traces for improved observability.
// #[derive(Serialize, Deserialize, Debug, Clone, Default)]
// pub struct UserData {
// /// Apply server-side prompt template to the request
// pub request_uuid: Option<uuid::Uuid>,
// }
/// StreamingResponse is the primary response object for the LLM Engine. The response stream
/// can emit three different types of messages. The Initialize and Finalize messages are optional
/// and primarily used over disaggreated transports to move states from the server to the client.
#[derive(Serialize,
Deserialize,
Debug)]
pub
enum
StreamingResponse
{
/// Initialize transports a Prologue object which communication the LLM Engine Context
Initialize
(
Option
<
Prologue
>
),
/// Step is the primary data in the response stream. It contains the StreamingCompletionResponse
Step
(
Box
<
StreamingCompletionResponse
>
),
/// Finalize is an optional final message in the response stream. It contains the Epilogue object which
/// is used to communicate extra information about the completion and the engine statistics.
Finalize
(
Option
<
Epilogue
>
),
}
// TODO(ryan) - this should be part of the internal api as it is not deserializble
// the public API should drop the Option<Arc<Stats>> in favor of Option<Stats>
// the two variants both serialize to the same json; however, the internal version
// can not be deserialized directly.
// we use the internal one on the server side to avoid the cost of cloning the Stats
// object; however, client side, we should always fully materialize the Stats object.
//
// TODO(ryan) - update this object to use an enum where we have the current definition be the
// StepResponse arm; then we will add the following arms:
// - Initialize(Prologue)
// - Step()
// - Finalize(Epilogue)
/// This is the first message that will be emitted by an Engine Response Stream
/// It indicates that the request has been preprocessed and queued for execution on the backend.
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Prologue
{
/// If the request was preprocessed with a prompt template, this will contain the formatted prompt
pub
formatted_prompt
:
Option
<
String
>
,
/// If the request did not contain TokenIds, this will contain the token_ids that were generated
/// from tokenizing the prompt.
pub
input_token_ids
:
Option
<
Vec
<
TokenIdType
>>
,
}
/// This is the final message that will be emitted by a Engine Response Stream when it
/// finishes without error. In some cases, the engine may emit an error which will indicate
/// the end of the steam. Another case in which an Finalize(Epilogue) will not be emitted is
/// if the response handler has stalled and too many responses
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Epilogue
{}
#[derive(Debug)]
pub
struct
StreamingCompletionResponse
{
pub
delta
:
Delta
,
pub
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
enum
StreamState
{
Active
,
...
...
@@ -506,6 +441,12 @@ pub struct SequencePositionData {
pub
logprobs
:
Option
<
LogProbs
>
,
}
#[derive(Debug)]
pub
struct
StreamingCompletionResponse
{
pub
delta
:
Delta
,
pub
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
// todo(ryan) - we need to create a DeltaBuilder which is a mutable object that can be passed
// around from the low-level compute engine to the high-level api. The DeltaBuilder will allow
// us to construct the Delta object at multiple layers in the streaming response path.
...
...
@@ -549,134 +490,6 @@ pub struct Usage {
pub
output_tokens_count
:
usize
,
}
// todo(ryan) - we need to update this object to make it more generic
// we need to define a set of generic stats traits that allow those stats to be None
// then back them by a concrete implementation like a TrtllmStats object
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
Stats
{
/// Time since the last Epoch/Forward Pass in microseconds (us).
/// This is measured and recorded by the Response Router rather then the
/// Inference Engine. Note, when evaluating the responses, if the this
/// values is greater then the stream's measured value, then there was a gap
/// between forward passes. In normal operation, the value of this field should
/// be less than the recorded value on the response stream.
pub
time_since_last_forward_pass_us
:
Option
<
u64
>
,
pub
request_active_count
:
u32
,
pub
request_context_count
:
u32
,
pub
request_generation_count
:
u32
,
pub
request_scheduled_count
:
u32
,
pub
request_max_count
:
u32
,
pub
kv_free_cache_blocks
:
u64
,
pub
kv_max_cache_blocks
:
u64
,
pub
kv_used_cache_blocks
:
u64
,
pub
kv_tokens_per_cache_block
:
u64
,
pub
runtime_cpu_memory_usage
:
u64
,
pub
runtime_gpu_memory_usage
:
u64
,
pub
runtime_pinned_memory_usage
:
u64
,
pub
iteration_counter
:
u64
,
pub
microbatch_id
:
u64
,
pub
total_context_tokens
:
u32
,
pub
timestamp
:
String
,
}
impl
Serialize
for
StreamingCompletionResponse
{
fn
serialize
<
S
>
(
&
self
,
serializer
:
S
)
->
Result
<
S
::
Ok
,
S
::
Error
>
where
S
:
Serializer
,
{
let
mut
state
=
serializer
.serialize_struct
(
"StreamingCompletionResponse"
,
2
)
?
;
// Serialize `delta` field
state
.serialize_field
(
"delta"
,
&
self
.delta
)
?
;
state
.end
()
}
}
impl
<
'de
>
Deserialize
<
'de
>
for
StreamingCompletionResponse
{
fn
deserialize
<
D
>
(
deserializer
:
D
)
->
Result
<
Self
,
D
::
Error
>
where
D
:
Deserializer
<
'de
>
,
{
// Create a temporary struct for deserialization
#[derive(Deserialize)]
struct
TempResponse
{
delta
:
Delta
,
logprobs
:
Option
<
ChatCompletionLogprobs
>
,
}
let
TempResponse
{
delta
,
logprobs
}
=
TempResponse
::
deserialize
(
deserializer
)
?
;
Ok
(
StreamingCompletionResponse
{
delta
,
logprobs
})
}
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
ScatterData
<
T
>
{
pub
x
:
Vec
<
T
>
,
pub
y
:
Vec
<
T
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
Trace
{
pub
time_to_first_token
:
u64
,
pub
token_to_token
:
Vec
<
u64
>
,
pub
start
:
SystemTime
,
pub
complete
:
SystemTime
,
pub
initial_tokens
:
u32
,
pub
max_tokens
:
u32
,
pub
t2ft_iteration_count
:
u64
,
pub
t2t_iteration_count
:
Vec
<
u64
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
PerformanceModel
{
// linear regression parameters fitting t2ft vs. initial tokens
pub
t2ft_intercept
:
f64
,
pub
t2ft_slope
:
f64
,
// linear regression parameters fitting t2tl vs. initial tokens
pub
t2tl_intercept
:
f64
,
pub
t2tl_slope
:
f64
,
// r2 values from the regression
pub
t2ft_fit_r2
:
f64
,
pub
t2tl_fit_r2
:
f64
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
CalibrationResults
{
pub
effective_flops
:
f64
,
pub
effective_memory_bandwidth
:
f64
,
pub
max_q
:
u32
,
pub
performance_model
:
PerformanceModel
,
pub
traces
:
Vec
<
Trace
>
,
pub
t2ft_scatter_data
:
ScatterData
<
f64
>
,
pub
t2tl_scatter_data
:
ScatterData
<
f64
>
,
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
LoadgenResults
{
pub
stats_by_iteration
:
HashMap
<
u64
,
Stats
>
,
pub
traces
:
Vec
<
Trace
>
,
}
impl
CompletionContext
{
/// Create a new CompletionContext
pub
fn
new
(
prompt
:
String
,
system_prompt
:
Option
<
String
>
)
->
Self
{
...
...
@@ -712,7 +525,6 @@ impl From<CompletionContext> for PromptType {
#[cfg(test)]
mod
tests
{
use
serde_json
;
use
super
::
*
;
...
...
@@ -759,360 +571,4 @@ mod tests {
panic!
(
"Expected a Completion variant"
);
}
}
// #[test]
// fn test_serialize_with_stats() {
// let response = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // Serialize the response
// let serialized = serde_json::to_string(&response).expect("Failed to serialize");
// // Expected JSON string (simplified)
// let expected = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// assert_eq!(
// serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
// serde_json::from_str::<serde_json::Value>(expected).unwrap()
// );
// }
#[test]
fn
test_serialize_without_stats
()
{
let
response
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
false
,
finish_reason
:
None
,
token_ids
:
None
,
tokens
:
None
,
text
:
None
,
sequence_length
:
None
,
index
:
None
,
cum_log_probs
:
None
,
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// Serialize the response
let
serialized
=
serde_json
::
to_string
(
&
response
)
.expect
(
"Failed to serialize"
);
// Expected JSON string
let
expected
=
r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#
;
assert_eq!
(
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
&
serialized
)
.unwrap
(),
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
expected
)
.unwrap
()
);
}
// #[test]
// fn test_deserialize_with_stats() {
// let json_data = r#"{
// "delta": {
// "is_complete": true,
// "finish_reason": "length",
// "token_ids": [101, 102, 103],
// "tokens": ["token1", "token2"],
// "text": "example text",
// "sequence_length": 3,
// "index": 0,
// "cum_log_probs": -0.5,
// "err_msg": null,
// "usage": null
// },
// "stats": {
// "time_since_last_forward_pass_us": 1000,
// "request_active_count": 2,
// "request_context_count": 1,
// "request_generation_count": 3,
// "request_scheduled_count": 1,
// "request_max_count": 10,
// "kv_free_cache_blocks": 500,
// "kv_max_cache_blocks": 1000,
// "kv_used_cache_blocks": 500,
// "kv_tokens_per_cache_block": 10,
// "runtime_cpu_memory_usage": 5000,
// "runtime_gpu_memory_usage": 2000,
// "runtime_pinned_memory_usage": 1000,
// "iteration_counter": 5,
// "microbatch_id": 12345,
// "total_context_tokens": 256,
// "timestamp": "2024-01-01T00:00:00Z"
// }
// }"#;
// // Deserialize the JSON string
// let deserialized: StreamingCompletionResponse =
// serde_json::from_str(json_data).expect("Failed to deserialize");
// // Expected response object
// let expected = StreamingCompletionResponse {
// delta: Delta {
// is_complete: true,
// finish_reason: Some(FinishReason::Length),
// token_ids: Some(vec![101, 102, 103]),
// tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
// text: Some("example text".to_string()),
// sequence_length: Some(3),
// index: Some(0),
// cum_log_probs: Some(-0.5),
// err_msg: None,
// usage: None,
// },
// logprobs: None,
// };
// // This is wieldy but we can no longer do assert_eq!(deserialized, expected);
// // because the struct no longer has the PartialEq trait
// assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
// assert_eq!(
// deserialized.delta.finish_reason,
// expected.delta.finish_reason
// );
// assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
// assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
// assert_eq!(deserialized.delta.text, expected.delta.text);
// assert_eq!(
// deserialized.delta.sequence_length,
// expected.delta.sequence_length
// );
// assert_eq!(deserialized.delta.index, expected.delta.index);
// assert_eq!(
// deserialized.delta.cum_log_probs,
// expected.delta.cum_log_probs
// );
// assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
// assert_eq!(deserialized.delta.usage, expected.delta.usage);
// assert_eq!(
// deserialized_stats.time_since_last_forward_pass_us,
// expected_stats.time_since_last_forward_pass_us
// );
// assert_eq!(
// deserialized_stats.request_active_count,
// expected_stats.request_active_count
// );
// assert_eq!(
// deserialized_stats.request_context_count,
// expected_stats.request_context_count
// );
// assert_eq!(
// deserialized_stats.request_generation_count,
// expected_stats.request_generation_count
// );
// assert_eq!(
// deserialized_stats.request_scheduled_count,
// expected_stats.request_scheduled_count
// );
// assert_eq!(
// deserialized_stats.request_max_count,
// expected_stats.request_max_count
// );
// assert_eq!(
// deserialized_stats.kv_free_cache_blocks,
// expected_stats.kv_free_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_max_cache_blocks,
// expected_stats.kv_max_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_used_cache_blocks,
// expected_stats.kv_used_cache_blocks
// );
// assert_eq!(
// deserialized_stats.kv_tokens_per_cache_block,
// expected_stats.kv_tokens_per_cache_block
// );
// assert_eq!(
// deserialized_stats.runtime_cpu_memory_usage,
// expected_stats.runtime_cpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_gpu_memory_usage,
// expected_stats.runtime_gpu_memory_usage
// );
// assert_eq!(
// deserialized_stats.runtime_pinned_memory_usage,
// expected_stats.runtime_pinned_memory_usage
// );
// assert_eq!(
// deserialized_stats.iteration_counter,
// expected_stats.iteration_counter
// );
// assert_eq!(
// deserialized_stats.microbatch_id,
// expected_stats.microbatch_id
// );
// assert_eq!(
// deserialized_stats.total_context_tokens,
// expected_stats.total_context_tokens
// );
// assert_eq!(deserialized_stats.timestamp, expected_stats.timestamp);
// }
#[test]
fn
test_deserialize_without_stats
()
{
let
json_data
=
r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#
;
// Deserialize the JSON string
let
deserialized
:
StreamingCompletionResponse
=
serde_json
::
from_str
(
json_data
)
.expect
(
"Failed to deserialize"
);
// Expected response object
let
expected
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
false
,
finish_reason
:
None
,
token_ids
:
None
,
tokens
:
None
,
text
:
None
,
sequence_length
:
None
,
index
:
None
,
cum_log_probs
:
None
,
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// This is wieldy but we can no longer do assert_eq!(deserialized, expected);
// because the struct no longer has the PartialEq trait
assert_eq!
(
deserialized
.delta.is_complete
,
expected
.delta.is_complete
);
assert_eq!
(
deserialized
.delta.finish_reason
,
expected
.delta.finish_reason
);
assert_eq!
(
deserialized
.delta.token_ids
,
expected
.delta.token_ids
);
assert_eq!
(
deserialized
.delta.tokens
,
expected
.delta.tokens
);
assert_eq!
(
deserialized
.delta.text
,
expected
.delta.text
);
assert_eq!
(
deserialized
.delta.sequence_length
,
expected
.delta.sequence_length
);
assert_eq!
(
deserialized
.delta.index
,
expected
.delta.index
);
assert_eq!
(
deserialized
.delta.cum_log_probs
,
expected
.delta.cum_log_probs
);
assert_eq!
(
deserialized
.delta.err_msg
,
expected
.delta.err_msg
);
assert_eq!
(
deserialized
.delta.usage
,
expected
.delta.usage
);
}
#[test]
fn
test_serialize_delta_and_none_stats
()
{
let
response
=
StreamingCompletionResponse
{
delta
:
Delta
{
is_complete
:
true
,
finish_reason
:
Some
(
FinishReason
::
Length
),
token_ids
:
Some
(
vec!
[
101
,
102
,
103
]),
tokens
:
Some
(
vec!
[
"token1"
.to_string
(),
"token2"
.to_string
()]),
text
:
Some
(
"example text"
.to_string
()),
sequence_length
:
Some
(
3
),
index
:
Some
(
0
),
cum_log_probs
:
Some
(
-
0.5
),
err_msg
:
None
,
usage
:
None
,
},
logprobs
:
None
,
};
// Serialize the response
let
serialized
=
serde_json
::
to_string
(
&
response
)
.expect
(
"Failed to serialize"
);
// Expected JSON string where stats is null
let
expected_json
=
r#"{
"delta": {
"is_complete": true,
"finish_reason": "length",
"token_ids": [101, 102, 103],
"tokens": ["token1", "token2"],
"text": "example text",
"sequence_length": 3,
"index": 0,
"cum_log_probs": -0.5,
"err_msg": null,
"usage": null
}
}"#
;
// Parse both the serialized response and the expected JSON as serde_json::Value for easy comparison
assert_eq!
(
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
&
serialized
)
.unwrap
(),
serde_json
::
from_str
::
<
serde_json
::
Value
>
(
expected_json
)
.unwrap
()
);
}
}
lib/llm/src/protocols/common/llm_backend.rs
View file @
9d7c5df5
...
...
@@ -15,14 +15,13 @@
use
serde
::{
Deserialize
,
Serialize
};
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
FinishReason
;
use
crate
::
protocols
::
TokenIdType
;
pub
type
TokenType
=
Option
<
String
>
;
pub
type
LogProbs
=
Vec
<
f64
>
;
pub
use
super
::
preprocessor
::
PreprocessedRequest
;
pub
use
super
::
FinishReason
;
#[derive(Serialize,
Deserialize,
Debug,
Clone,
PartialEq)]
pub
struct
BackendOutput
{
/// New token_ids generated from the LLM Engine
...
...
lib/llm/src/protocols/openai.rs
View file @
9d7c5df5
...
...
@@ -13,24 +13,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub
mod
chat_completions
;
pub
mod
completions
;
pub
mod
embeddings
;
pub
mod
models
;
pub
mod
nvext
;
use
std
::
fmt
::
Display
;
use
anyhow
::
Result
;
use
serde
::{
Deserialize
,
Serialize
};
use
std
::{
fmt
::
Display
,
ops
::{
Add
,
Div
,
Mul
,
Sub
},
};
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
ContentProvider
,
};
pub
mod
chat_completions
;
pub
mod
completions
;
pub
mod
embeddings
;
pub
mod
models
;
pub
mod
nvext
;
/// Minimum allowed value for OpenAI's `temperature` sampling option
pub
const
MIN_TEMPERATURE
:
f32
=
0.0
;
...
...
@@ -67,22 +65,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
/// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON
/// object corresponding to `R`; however, the comments in the SSE stream `: `
/// may correspond to other types of information, such as performance metrics,
/// as represented by other arms of this enum.
///
/// This is part of the common API as both the client and service need to agree
/// on the format of the streaming responses.
#[derive(Serialize,
Deserialize,
Debug)]
pub
enum
StreamingDelta
<
R
>
{
/// Represents a response delta from the API
Delta
(
R
),
Comment
(
String
),
}
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
AnnotatedDelta
<
R
>
{
pub
delta
:
R
,
...
...
@@ -183,43 +165,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
}
}
/// Common structure for chat completion responses; the only delta is the type of choices which differs
/// between streaming and non-streaming requests.
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
pub
struct
GenericCompletionResponse
<
C
>
// where
// C: Serialize + Clone,
{
/// A unique identifier for the chat completion.
pub
id
:
String
,
/// A list of chat completion choices. Can be more than one if n is greater than 1.
pub
choices
:
Vec
<
C
>
,
/// The Unix timestamp (in seconds) of when the chat completion was created.
pub
created
:
u64
,
/// The model used for the chat completion.
pub
model
:
String
,
/// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub
object
:
String
,
pub
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
/// This fingerprint represents the backend configuration that the model runs with.
///
/// Can be used in conjunction with the seed request parameter to understand when backend changes
/// have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
pub
system_fingerprint
:
Option
<
String
>
,
// TODO() - add NvResponseExtention
}
// todo - move to common location
fn
validate_range
<
T
>
(
value
:
Option
<
T
>
,
range
:
&
(
T
,
T
))
->
Result
<
Option
<
T
>>
where
...
...
@@ -235,30 +180,6 @@ where
Ok
(
Some
(
value
))
}
// todo - move to common location
/// scale value in `src` range to `dst` range
pub
fn
scale_value
<
T
>
(
value
:
&
T
,
src
:
&
(
T
,
T
),
dst
:
&
(
T
,
T
))
->
Result
<
T
>
where
T
:
Copy
+
PartialOrd
+
Add
<
Output
=
T
>
+
Sub
<
Output
=
T
>
+
Mul
<
Output
=
T
>
+
Div
<
Output
=
T
>
+
From
<
f32
>
,
{
let
dst_range
=
dst
.1
-
dst
.0
;
let
src_range
=
src
.1
-
src
.0
;
if
dst_range
==
T
::
from
(
0.0
)
{
anyhow
::
bail!
(
"dst range is 0"
);
}
if
src_range
==
T
::
from
(
0.0
)
{
anyhow
::
bail!
(
"src range is 0"
);
}
let
value_scaled
=
(
*
value
-
src
.0
)
/
src_range
;
Ok
(
dst
.0
+
(
value_scaled
*
dst_range
))
}
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
Send
+
Sync
+
'static
{
...
...
@@ -270,37 +191,3 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
/// Gets the current prompt token count (Input Sequence Length).
fn
get_isl
(
&
self
)
->
Option
<
u32
>
;
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_validate_range
()
{
assert_eq!
(
validate_range
(
Some
(
0.5
),
&
(
0.0
,
1.0
))
.unwrap
(),
Some
(
0.5
));
assert_eq!
(
validate_range
(
Some
(
0.0
),
&
(
0.0
,
1.0
))
.unwrap
(),
Some
(
0.0
));
assert_eq!
(
validate_range
(
Some
(
1.0
),
&
(
1.0
,
1.0
))
.unwrap
(),
Some
(
1.0
));
assert_eq!
(
validate_range
(
Some
(
1_i32
),
&
(
1
,
1
))
.unwrap
(),
Some
(
1
));
assert_eq!
(
validate_range
(
Some
(
1.1
),
&
(
0.0
,
1.0
))
.unwrap_err
()
.to_string
(),
"Value 1.1 is out of range [0, 1]"
);
assert_eq!
(
validate_range
(
Some
(
-
0.1
),
&
(
0.0
,
1.0
))
.unwrap_err
()
.to_string
(),
"Value -0.1 is out of range [0, 1]"
);
}
#[test]
fn
test_scaled_value
()
{
assert_eq!
(
scale_value
(
&
0.5
,
&
(
0.0
,
1.0
),
&
(
0.0
,
2.0
))
.unwrap
(),
1.0
);
assert_eq!
(
scale_value
(
&
0.0
,
&
(
0.0
,
1.0
),
&
(
0.0
,
2.0
))
.unwrap
(),
0.0
);
assert_eq!
(
scale_value
(
&-
1.0
,
&
(
-
2.0
,
2.0
),
&
(
1.0
,
2.0
))
.unwrap
(),
1.25
);
assert
!
(
scale_value
(
&
1.0
,
&
(
1.0
,
1.0
),
&
(
0.0
,
2.0
))
.is_err
());
}
}
lib/llm/src/protocols/openai/chat_completions.rs
View file @
9d7c5df5
...
...
@@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
use
super
::
nvext
::
NvExt
;
use
super
::
nvext
::
NvExtProvider
;
use
super
::
OpenAISamplingOptionsProvider
;
use
super
::
OpenAIStopConditionsProvider
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
mod
aggregator
;
mod
delta
;
...
...
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
View file @
9d7c5df5
...
...
@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::{
collections
::
HashMap
,
pin
::
Pin
};
use
futures
::{
Stream
,
StreamExt
};
use
super
::{
NvCreateChatCompletionResponse
,
NvCreateChatCompletionStreamResponse
};
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::{
collections
::
HashMap
,
pin
::
Pin
};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type
DataStream
<
T
>
=
Pin
<
Box
<
dyn
Stream
<
Item
=
T
>
+
Send
+
Sync
>>
;
...
...
lib/llm/src/protocols/openai/embeddings.rs
View file @
9d7c5df5
...
...
@@ -13,17 +13,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
mod
aggregator
;
mod
nvext
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
// pub use delta::DeltaGenerator;
pub
use
aggregator
::
DeltaAggregator
;
use
dynamo_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
pub
use
nvext
::{
NvExt
,
NvExtProvider
};
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
NvCreateEmbeddingRequest
{
...
...
lib/llm/src/protocols/openai/embeddings/aggregator.rs
View file @
9d7c5df5
...
...
@@ -13,15 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
pin
::
Pin
;
use
futures
::{
Stream
,
StreamExt
};
use
super
::
NvCreateEmbeddingResponse
;
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
convert_sse_stream
,
Annotated
,
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::
pin
::
Pin
;
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type
DataStream
<
T
>
=
Pin
<
Box
<
dyn
Stream
<
Item
=
T
>
+
Send
+
Sync
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment