Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
057f8f47
Commit
057f8f47
authored
Feb 28, 2025
by
Graham King
Committed by
GitHub
Feb 28, 2025
Browse files
feat: TensorRT-LLM engine (#317)
Engine, `tio` support and docs. Proof of concept / experimental.
parent
11a36651
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
556 additions
and
0 deletions
+556
-0
lib/llm/src/engines/trtllm/executor/processors/kv.rs
lib/llm/src/engines/trtllm/executor/processors/kv.rs
+98
-0
lib/llm/src/engines/trtllm/executor/processors/response.rs
lib/llm/src/engines/trtllm/executor/processors/response.rs
+165
-0
lib/llm/src/engines/trtllm/executor/protocols.rs
lib/llm/src/engines/trtllm/executor/protocols.rs
+173
-0
lib/llm/src/engines/trtllm/executor/protocols/kv.rs
lib/llm/src/engines/trtllm/executor/protocols/kv.rs
+16
-0
lib/llm/src/engines/trtllm/executor/protocols/outputs.rs
lib/llm/src/engines/trtllm/executor/protocols/outputs.rs
+82
-0
lib/llm/src/engines/trtllm/executor/protocols/stats.rs
lib/llm/src/engines/trtllm/executor/protocols/stats.rs
+22
-0
No files found.
lib/llm/src/engines/trtllm/executor/processors/kv.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
crate
::
kv_router
::
protocols
::
KvCacheEvents
;
use
std
::{
sync
::{
atomic
::{
AtomicBool
,
Ordering
},
Arc
,
Weak
,
},
thread
,
};
use
tokio
::
sync
::
broadcast
;
use
super
::
*
;
/// Capacity of the broadcast channel used to fan out KV cache events to
/// subscribers. Sized generously so that slow subscribers lag (per tokio
/// broadcast semantics) rather than stall the producer thread.
const KV_EVENT_CHANNEL_CAPACITY: usize = 65536;

/// Sending half of the KV cache event broadcast channel (producer side).
type EventChannelType = broadcast::Sender<KvCacheEvents>;

/// Receiving half handed out to each subscriber of KV cache events.
pub type KvEventSubscriptionChannel = broadcast::Receiver<KvCacheEvents>;
/// Owns a dedicated OS thread that pulls KV cache events from the executor
/// and rebroadcasts them to any number of subscribers.
pub struct KvEventProcessor {
    // Join handle for the background event-processing thread.
    handle: thread::JoinHandle<()>,
    // Flag polled by the worker loop; set to `true` to request shutdown.
    shutdown: Arc<AtomicBool>,
    // Weak reference to the broadcast sender. The worker thread holds the
    // only strong reference, so `subscribe` fails once the worker is gone.
    channel: Weak<EventChannelType>,
}
impl KvEventProcessor {
    /// Creates a new KV Event Processor: spawns the worker thread and wires
    /// up the shutdown flag and broadcast channel it uses.
    pub fn new(state: ProcessorState) -> Self {
        // Flag the worker polls to know when it should stop.
        let stop_flag = Arc::new(AtomicBool::new(false));

        // Broadcast channel used to fan events out to subscribers. Only the
        // sending half is kept; receivers are minted on demand in `subscribe`.
        let sender = Arc::new(broadcast::channel(KV_EVENT_CHANNEL_CAPACITY).0);
        let weak_sender = Arc::downgrade(&sender);

        // The worker takes the strong sender and a clone of the stop flag.
        let worker = {
            let stop_flag = stop_flag.clone();
            std::thread::spawn(move || {
                process_events(state, stop_flag, sender);
            })
        };

        KvEventProcessor {
            handle: worker,
            shutdown: stop_flag,
            channel: weak_sender,
        }
    }

    /// Subscribes to the KV Events broadcast channel.
    /// Multiple subscribers can be created to monitor the KV Events.
    /// Returns `None` once the worker thread (which owns the sender) is gone.
    pub fn subscribe(&self) -> Option<broadcast::Receiver<KvCacheEvents>> {
        let Some(sender) = self.channel.upgrade() else {
            return None;
        };
        Some(sender.subscribe())
    }

    /// Requests shutdown, then joins the thread and waits for it to finish.
    pub fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}
/// Worker loop: blocks on the executor for batches of KV cache events and
/// rebroadcasts them until either the executor signals shutdown or the
/// owning `KvEventProcessor` sets the shutdown flag.
fn process_events(
    state: ProcessorState,
    shutdown: Arc<AtomicBool>,
    channel: Arc<EventChannelType>,
) {
    loop {
        // this blocks the thread until events are ready or the server is shutdown
        // (fixed: the expect message previously said "responses", copy-pasted
        // from the response processor, which would mislead during a panic)
        let mut message = state
            .executor
            .await_kv_events()
            .expect("Failed to await KV events");

        // Shut down if either the executor says so or our owner asked us to,
        // and propagate that decision to subscribers via the message itself.
        let should_shutdown = message.shutdown || shutdown.load(Ordering::Relaxed);
        message.shutdown = should_shutdown;

        // A send error only means there are currently no subscribers; events
        // are droppable, so log at debug level and keep going.
        if let Err(e) = channel.send(message) {
            tracing::debug!("Failed to send message to channel: {:?}", e);
        }

        if should_shutdown {
            tracing::debug!("Shutting down KV Event Processor");
            break;
        }
    }
}
lib/llm/src/engines/trtllm/executor/processors/response.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
thread
;
use
tokio
::
sync
::
mpsc
;
use
super
::
*
;
use
crate
::
engines
::
trtllm
::
executor
::
ResponseQueues
;
/// Owns a dedicated OS thread that pulls responses from the executor and
/// routes each one to the per-client response queue it belongs to.
pub struct ResponseProcessor {
    // Join handle for the background response-routing thread.
    handle: thread::JoinHandle<()>,
}
impl ResponseProcessor {
    /// Spawns the response-routing thread and returns the owning handle.
    pub fn new(state: ProcessorState, response_queues: ResponseQueues) -> Self {
        let worker = std::thread::spawn(move || {
            process_responses(state, response_queues);
        });
        Self { handle: worker }
    }

    /// Block and wait for the response processor to finish
    pub fn join(self) -> thread::Result<()> {
        self.handle.join()
    }
}
#[derive(Debug,
thiserror::Error)]
enum
ResponseError
{
#[error(
"Response queue dropped; possible client disconnect"
)]
ResponseQueueDropped
,
#[error(
"Response channel closed; possible client disconnect"
)]
ChannelClosed
,
#[error(
"Response channel full; backpress detected in response stream"
)]
ChannelFull
,
#[error(
"Invalid response: no error or result found"
)]
InvalidResponse
,
/// Error indicating that TensorRT LLM returned an error
/// This also indicates that the request was not successful and no further responses
/// will be sent for this request
#[error(
"TensorRT LLM Engine Error: {0}"
)]
EngineError
(
String
),
#[error(
"Completed successfully"
)]
RequestComplete
,
}
/// Worker loop: blocks on the executor for batches of responses and delivers
/// each one to the queue registered under its `client_id`, cleaning up queues
/// and cancelling requests as delivery errors dictate. Exits when the
/// executor reports shutdown.
fn process_responses(state: ProcessorState, response_queues: ResponseQueues) {
    loop {
        // this blocks the thread until the response is ready or the server is shutdown
        let message = state
            .executor
            .await_responses()
            .expect("Failed to await responses");

        // check shutdown condition
        if message.shutdown {
            tracing::info!("Server shutdown detected");
            break;
        }

        // process responses - hold the lock while we iterate to avoid any contention
        // grabbing and releasing it for each response
        let mut queues = response_queues.lock().unwrap();
        for output in message.responses {
            let request_id = output.request_id;
            // NOTE(review): a missing client_id panics this routing thread via
            // the expect below — confirm the executor always populates it.
            let client_id = output.client_id.expect("client_id is missing");
            let tx = queues.get(&client_id);
            match try_send(tx, output) {
                Ok(_) => {}
                Err(e) => {
                    tracing::trace!(client_id, "processing response: {}", e);
                    match e {
                        ResponseError::InvalidResponse => {
                            // this would likely be a bug on the server; we expect the oneof to be set
                            tracing::warn!(client_id, "Invalid response; No action required");
                        }
                        ResponseError::EngineError(_) => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelFull => {
                            // critical error
                            tracing::error!(
                                client_id,
                                "Alert: backpressure detected in response stream"
                            );
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelClosed => {
                            // the first indication the client has disconnected
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ResponseQueueDropped => {
                            // if we get a response for a dropped queue, we need to cancel the request
                            state.executor.cancel_request(request_id);
                        }
                        ResponseError::RequestComplete => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                    }
                }
            }
        }
    }
}
/// Delivers a single executor response to a client's channel.
///
/// Returns `Ok(())` when the response was sent and more responses are
/// expected, or an error describing why routing should stop for this client:
/// the queue is gone, the channel is closed or full, the response is
/// malformed, the engine reported an error, or the request completed normally
/// (`RequestComplete`, which is a terminal marker rather than a failure).
fn try_send(
    tx: Option<&mpsc::Sender<Result<protocols::Output>>>,
    response: protocols::Response,
) -> Result<(), ResponseError> {
    // Terminal status to report to the caller after attempting delivery;
    // this is distinct from `result`, which is what gets sent to the client.
    let mut rc = Ok(());
    let tx = tx.ok_or(ResponseError::ResponseQueueDropped)?;
    // Exactly one of `output` / `error_msg` must be set; any other
    // combination is rejected as an invalid response.
    let result = match (response.output, response.error_msg) {
        (Some(output), None) => {
            if output.is_final {
                // Final chunk: still deliver it below, then signal completion.
                rc = Err(ResponseError::RequestComplete);
            }
            Ok(output)
        }
        (None, Some(e)) => {
            // Engine error: forward it to the client and report it upstream.
            rc = Err(ResponseError::EngineError(e.clone()));
            Err(ResponseError::EngineError(e.clone()))
        }
        (None, None) => return Err(ResponseError::InvalidResponse),
        (Some(_), Some(_)) => return Err(ResponseError::InvalidResponse),
    };
    // Non-blocking send so a slow client surfaces as ChannelFull instead of
    // stalling the routing thread.
    match tx.try_send(result.map_err(|e| e.into())) {
        Ok(_) => {}
        Err(e) => match e {
            mpsc::error::TrySendError::Closed(_) => {
                return Err(ResponseError::ChannelClosed);
            }
            mpsc::error::TrySendError::Full(_) => {
                return Err(ResponseError::ChannelFull);
            }
        },
    }
    rc
}
lib/llm/src/engines/trtllm/executor/protocols.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
derive_builder
::
Builder
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde_repr
::{
Deserialize_repr
,
Serialize_repr
};
pub
mod
kv
;
pub
mod
outputs
;
pub
mod
stats
;
pub
use
outputs
::
*
;
/// Sampling parameters forwarded to the TensorRT-LLM executor.
/// Every `Option` field is omitted from the serialized request when `None`,
/// leaving the engine to apply its own defaults.
#[derive(Serialize, Deserialize, Default)]
pub struct SamplingConfig {
    pub beam_width: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_min: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_reset_ids: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_decay: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub beam_search_diversity_rate: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub length_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub early_stopping: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_repeat_ngram_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_return_sequences: Option<u32>,
}
/// Flags controlling which optional artifacts the engine includes in each
/// response (log-probs, logits, encoder output, input echo).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OutputConfig {
    pub return_log_probs: bool,
    pub return_context_logits: bool,
    pub return_generation_logits: bool,
    pub exclude_input_from_output: bool,
    pub return_encoder_output: bool,
}
/// KV cache retention override: an optional priority plus an optional time
/// window. Mirrors the executor's protobuf message; the original wrapper
/// types are noted inline.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RetentionPriorityAndDuration {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retention_priority: Option<u32>, // google.protobuf.UInt32Value
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// Retention override applied to a range of token positions; the end of the
/// range and the duration are optional (protobuf wrapper types noted inline).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenRangeRetentionConfig {
    pub token_start: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_end: Option<u32>, // google.protobuf.UInt32Value
    pub priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// KV cache retention policy for a request: per-token-range overrides plus
/// the priority (and optional duration) applied during the decode phase.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheRetentionConfig {
    pub token_range_retention_configs: Vec<TokenRangeRetentionConfig>,
    pub decode_retention_priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decode_duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// A generation request in the shape the TensorRT-LLM executor expects.
/// Only the fields currently wired up are live; the commented-out fields
/// track the remainder of the executor's request message for future work.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct Request {
    /// Prompt, already tokenized.
    pub input_token_ids: Vec<u32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Whether to stream partial outputs as they are produced.
    pub streaming: bool,
    // pub sampling_config: SamplingConfig,
    // pub output_config: OutputConfig,
    /// End-of-sequence token id; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_id: Option<u32>,
    // pub pad_id: Option<u32>, // google.protobuf.UInt32Value
    // pub position_ids: Vec<u32>,
    // pub bad_words: Vec<u32>,
    // pub stop_words: Vec<u32>,
    // pub embedding_bias: Vec<u8>, // bytes
    // // TODO: Add external_draft_tokens_config: ExternalDraftTokensConfig
    // // TODO: Add prompt_tuning_config: PromptTuningConfig
    // // TODO: Add lora_config: LoraConfig
    // // TODO: Add lookahead_config: LookaheadDecodingConfig
    // pub kv_cache_retention_config: KvCacheRetentionConfig,
    // pub logits_post_processor_name: String,
    // pub encoder_input_token_ids: Vec<u32>,
    // pub client_id: Option<u64>, // google.protobuf.UInt64Value
    // pub return_all_generated_tokens: bool,
    // pub priority: f32,
    // pub request_type: u32,
    // // TODO: Add context_phase_params: ContextPhaseParams
    // pub encoder_input_features: Vec<u8>, // bytes
    // pub encoder_output_length: Option<u32>, // google.protobuf.UInt32Value
    // pub cross_attention_mask: Vec<u8>, // bytes
    // pub num_return_sequences: u32,
    // // TODO: Add eagle_config: EagleConfig
    // pub skip_cross_attn_blocks: Vec<u8>, // bytes
}
// todo - return a Result
impl Request {
    /// Builds a streaming request from tokenized input and a token budget.
    ///
    /// NOTE(review): `end_id` is never set here. Confirm the builder supplies
    /// a default for it; with derive_builder's defaults an uninitialized
    /// field makes `build()` return an error, and this `unwrap` would panic.
    pub fn new(input_token_ids: Vec<u32>, max_tokens: u32) -> Self {
        RequestBuilder::default()
            .input_token_ids(input_token_ids)
            .max_tokens(max_tokens)
            .streaming(true)
            .build()
            .unwrap()
    }
}
// todo convert to a TryFrom
impl From<crate::protocols::common::llm_backend::BackendInput> for Request {
    /// Maps a backend input onto an executor request: streams by default,
    /// caps generation at the caller's `max_tokens` (falling back to 16 when
    /// unset), and uses the last listed EOS token id, if any, as the
    /// end-of-sequence marker.
    fn from(input: crate::protocols::common::llm_backend::BackendInput) -> Self {
        // Default generation budget when the caller did not specify one.
        let token_budget = input.stop_conditions.max_tokens.unwrap_or(16);
        // Last listed EOS token wins; `None` leaves `end_id` unset.
        let eos = input.eos_token_ids.last().cloned();

        RequestBuilder::default()
            .input_token_ids(input.token_ids)
            .max_tokens(token_budget)
            .streaming(true)
            .end_id(eos)
            .build()
            .unwrap()
    }
}
lib/llm/src/engines/trtllm/executor/protocols/kv.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub
use
crate
::
kv_router
::
protocols
::
ForwardPassMetrics
;
lib/llm/src/engines/trtllm/executor/protocols/outputs.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
crate
::
protocols
::{
common
::{
self
},
TokenIdType
,
};
/// A batch of responses pulled from the executor in one await cycle.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Responses {
    /// Individual per-request responses in this batch.
    pub responses: Vec<Response>,
    /// Set when the executor is shutting down.
    pub shutdown: bool,
}
/// One executor response; exactly one of `error_msg` / `output` is expected
/// to be set (the response processor rejects any other combination).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Response {
    /// Identifier the executor assigned to the request.
    pub request_id: u64,
    pub client_id: Option<u64>,    // Optional client ID.
    pub error_msg: Option<String>, // Error message if the request failed.
    pub output: Option<Output>,    // Output if the request succeeded.
}
/// One chunk of generated output for a request. Optional fields are only
/// serialized when present.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Output {
    /// True on the last chunk for the request.
    pub is_final: bool,
    /// Token ids generated in this chunk.
    pub token_ids: Vec<TokenIdType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cum_log_prob: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<Vec<f64>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReasonEnum>,
}
/// Why generation stopped, serialized as its `u8` discriminant (via
/// serde_repr) to match the executor's wire encoding.
#[derive(Serialize_repr, Deserialize_repr, Debug, Clone)]
#[repr(u8)]
pub enum FinishReasonEnum {
    /// Generation is still in progress.
    FinishReasonNotDone = 0,
    /// Stopped at the end-of-sequence token.
    FinishReasonEos = 1,
    /// Stopped on a stop condition.
    FinishReasonStop = 2,
    /// Stopped at the token-count limit.
    FinishReasonLength = 3,
}
impl From<Output> for common::llm_backend::LLMEngineOutput {
    /// Translates an executor output chunk into the common backend output,
    /// collapsing `FinishReasonNotDone` to "no finish reason yet".
    fn from(output: Output) -> Self {
        let finish_reason = output.finish_reason.and_then(|reason| match reason {
            FinishReasonEnum::FinishReasonNotDone => None,
            FinishReasonEnum::FinishReasonEos => Some(common::FinishReason::EoS),
            FinishReasonEnum::FinishReasonStop => Some(common::FinishReason::Stop),
            FinishReasonEnum::FinishReasonLength => Some(common::FinishReason::Length),
        });

        // todo - propagate mdcsum
        common::llm_backend::LLMEngineOutput {
            token_ids: output.token_ids,
            tokens: None,
            text: None,
            cum_log_probs: output.cum_log_prob,
            // NOTE(review): `output.log_probs` is discarded here — confirm
            // whether it should be propagated.
            log_probs: None,
            finish_reason,
        }
    }
}
lib/llm/src/engines/trtllm/executor/protocols/stats.rs
0 → 100644
View file @
057f8f47
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
kv
::
ForwardPassMetrics
;
use
serde
::{
Deserialize
,
Serialize
};
/// A batch of per-forward-pass metrics reported by the engine for one
/// iteration of stats collection.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct IterStats {
    pub stats: Vec<ForwardPassMetrics>,
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment