Commit 057f8f47 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: TensorRT-LLM engine (#317)

Engine, `tio` support and docs.

Proof of concept / experimental.
parent 11a36651
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::protocols::KvCacheEvents;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
};
use tokio::sync::broadcast;
use super::*;
/// Capacity of the broadcast channel used to fan out KV cache events.
/// With tokio broadcast semantics, a subscriber that falls more than this
/// many messages behind will start missing events.
const KV_EVENT_CHANNEL_CAPACITY: usize = 65536;

/// Sending half of the KV event broadcast channel (internal).
type EventChannelType = broadcast::Sender<KvCacheEvents>;

/// Receiving half handed out to subscribers of KV cache events.
pub type KvEventSubscriptionChannel = broadcast::Receiver<KvCacheEvents>;
/// Owns the background thread that pumps KV cache events from the executor
/// into a broadcast channel that callers can subscribe to.
pub struct KvEventProcessor {
    // Join handle for the worker thread spawned in `new`.
    handle: thread::JoinHandle<()>,
    // Cooperative shutdown flag checked by the worker loop.
    shutdown: Arc<AtomicBool>,
    // Weak reference: the worker thread holds the only strong reference,
    // so `subscribe` returns `None` once the thread has exited.
    channel: Weak<EventChannelType>,
}
impl KvEventProcessor {
/// Creates a new KV Event Processor
pub fn new(state: ProcessorState) -> Self {
// Shutdown Token
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
// Event Channel
let channel = Arc::new(broadcast::channel(KV_EVENT_CHANNEL_CAPACITY).0);
let channel_clone = channel.clone();
let handle = std::thread::spawn(move || {
process_events(state, shutdown_clone, channel_clone);
});
KvEventProcessor {
handle,
shutdown,
channel: Arc::downgrade(&channel),
}
}
/// Subscribes to the KV Events broadcast channel
/// Multiple subscribers can be created to monitor the KV Events
pub fn subscribe(&self) -> Option<broadcast::Receiver<KvCacheEvents>> {
self.channel.upgrade().map(|channel| channel.subscribe())
}
/// Joins the thread and waits for it to finish
pub fn join(self) -> thread::Result<()> {
self.shutdown.store(true, Ordering::Relaxed);
self.handle.join()
}
}
/// Worker loop: forwards KV event batches from the executor to subscribers
/// until either the executor or the owning `KvEventProcessor` signals shutdown.
fn process_events(
    state: ProcessorState,
    shutdown: Arc<AtomicBool>,
    channel: Arc<EventChannelType>,
) {
    loop {
        // this blocks the thread until the response is ready or the server is shutdown
        let mut batch = state
            .executor
            .await_kv_events()
            .expect("Failed to await responses");

        // Fold the external shutdown flag into the message so subscribers
        // observe the final state in-band.
        if shutdown.load(Ordering::Relaxed) {
            batch.shutdown = true;
        }
        let stopping = batch.shutdown;

        // A send error means there are currently no subscribers; not fatal.
        if let Err(send_err) = channel.send(batch) {
            tracing::debug!("Failed to send message to channel: {:?}", send_err);
        }

        if stopping {
            tracing::debug!("Shutting down KV Event Processor");
            break;
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::thread;
use tokio::sync::mpsc;
use super::*;
use crate::engines::trtllm::executor::ResponseQueues;
/// Owns the background thread that routes executor responses to the
/// per-client response queues.
pub struct ResponseProcessor {
    // Join handle for the worker thread spawned in `new`.
    handle: thread::JoinHandle<()>,
}
impl ResponseProcessor {
    /// Spawns the worker thread that pumps executor responses into the
    /// per-client queues.
    pub fn new(state: ProcessorState, response_queues: ResponseQueues) -> Self {
        let worker = std::thread::spawn(move || process_responses(state, response_queues));
        Self { handle: worker }
    }

    /// Block and wait for the response processor to finish
    pub fn join(self) -> thread::Result<()> {
        self.handle.join()
    }
}
#[derive(Debug, thiserror::Error)]
enum ResponseError {
#[error("Response queue dropped; possible client disconnect")]
ResponseQueueDropped,
#[error("Response channel closed; possible client disconnect")]
ChannelClosed,
#[error("Response channel full; backpress detected in response stream")]
ChannelFull,
#[error("Invalid response: no error or result found")]
InvalidResponse,
/// Error indicating that TensorRT LLM returned an error
/// This also indicates that the request was not successful and no further responses
/// will be sent for this request
#[error("TensorRT LLM Engine Error: {0}")]
EngineError(String),
#[error("Completed successfully")]
RequestComplete,
}
/// Worker loop: pulls response batches from the executor and dispatches each
/// response to its client's queue, cleaning up queues and cancelling requests
/// based on the delivery outcome.
fn process_responses(state: ProcessorState, response_queues: ResponseQueues) {
    loop {
        // this blocks the thread until the response is ready or the server is shutdown
        let message = state
            .executor
            .await_responses()
            .expect("Failed to await responses");
        // check shutdown condition
        if message.shutdown {
            tracing::info!("Server shutdown detected");
            break;
        }
        // process responses - hold the lock while we iterate to avoid any contention
        // grabbing and releasing it for each response
        let mut queues = response_queues.lock().unwrap();
        for output in message.responses {
            let request_id = output.request_id;
            // NOTE(review): assumes the executor always tags responses with a
            // client_id; this panics otherwise — confirm that invariant.
            let client_id = output.client_id.expect("client_id is missing");
            let tx = queues.get(&client_id);
            match try_send(tx, output) {
                Ok(_) => {}
                Err(e) => {
                    // Each error variant encodes a distinct cleanup policy below.
                    tracing::trace!(client_id, "processing response: {}", e);
                    match e {
                        ResponseError::InvalidResponse => {
                            // this would likely be a bug on the server; we expect the oneof to be set
                            tracing::warn!(client_id, "Invalid response; No action required");
                        }
                        ResponseError::EngineError(_) => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelFull => {
                            // critical error
                            tracing::error!(
                                client_id,
                                "Alert: backpressure detected in response stream"
                            );
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelClosed => {
                            // the first indication the client has disconnected
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ResponseQueueDropped => {
                            // if we get a response for a dropped queue, we need to cancel the request
                            state.executor.cancel_request(request_id);
                        }
                        ResponseError::RequestComplete => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                    }
                }
            }
        }
    }
}
fn try_send(
tx: Option<&mpsc::Sender<Result<protocols::Output>>>,
response: protocols::Response,
) -> Result<(), ResponseError> {
let mut rc = Ok(());
let tx = tx.ok_or(ResponseError::ResponseQueueDropped)?;
let result = match (response.output, response.error_msg) {
(Some(output), None) => {
if output.is_final {
rc = Err(ResponseError::RequestComplete);
}
Ok(output)
}
(None, Some(e)) => {
rc = Err(ResponseError::EngineError(e.clone()));
Err(ResponseError::EngineError(e.clone()))
}
(None, None) => return Err(ResponseError::InvalidResponse),
(Some(_), Some(_)) => return Err(ResponseError::InvalidResponse),
};
match tx.try_send(result.map_err(|e| e.into())) {
Ok(_) => {}
Err(e) => match e {
mpsc::error::TrySendError::Closed(_) => {
return Err(ResponseError::ChannelClosed);
}
mpsc::error::TrySendError::Full(_) => {
return Err(ResponseError::ChannelFull);
}
},
}
rc
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
pub mod kv;
pub mod outputs;
pub mod stats;
pub use outputs::*;
/// Sampling parameters forwarded to the TRT-LLM executor.
/// Optional fields are omitted from the serialized form when unset.
// Added `Debug, Clone` for consistency with the other config structs in this
// module (OutputConfig, KvCacheRetentionConfig, Request, ...), which all
// derive them.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingConfig {
    pub beam_width: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_min: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_reset_ids: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_decay: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub beam_search_diversity_rate: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub length_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub early_stopping: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_repeat_ngram_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_return_sequences: Option<u32>,
}
/// Controls which optional artifacts the engine returns alongside tokens.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OutputConfig {
    pub return_log_probs: bool,
    pub return_context_logits: bool,
    pub return_generation_logits: bool,
    // When true, the prompt tokens are not echoed back in the output.
    pub exclude_input_from_output: bool,
    pub return_encoder_output: bool,
}
/// Optional retention priority/duration pair for KV cache blocks.
/// Both fields mirror protobuf wrapper types and are omitted when unset.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RetentionPriorityAndDuration {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retention_priority: Option<u32>, // google.protobuf.UInt32Value
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// Retention settings for a token range `[token_start, token_end)` of the
/// KV cache; `token_end == None` presumably means "to the end of the
/// sequence" — TODO confirm against the TRT-LLM proto.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenRangeRetentionConfig {
    pub token_start: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_end: Option<u32>, // google.protobuf.UInt32Value
    pub priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// Per-request KV cache retention policy: per-token-range settings plus a
/// priority/duration applied during the decode phase.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheRetentionConfig {
    pub token_range_retention_configs: Vec<TokenRangeRetentionConfig>,
    pub decode_retention_priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decode_duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// A generation request sent to the TRT-LLM executor.
///
/// Only the fields currently wired up are active; the commented-out fields
/// below track the full executor request schema and are kept as a roadmap.
/// Constructed via the derived `RequestBuilder`.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct Request {
    // Prompt as token ids (tokenization happens upstream).
    pub input_token_ids: Vec<u32>,
    // Maximum number of tokens to generate.
    pub max_tokens: u32,
    // Stream outputs incrementally rather than returning one final response.
    pub streaming: bool,
    // pub sampling_config: SamplingConfig,
    // pub output_config: OutputConfig,
    // End-of-sequence token id; generation stops when it is produced.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_id: Option<u32>,
    // pub pad_id: Option<u32>, // google.protobuf.UInt32Value
    // pub position_ids: Vec<u32>,
    // pub bad_words: Vec<u32>,
    // pub stop_words: Vec<u32>,
    // pub embedding_bias: Vec<u8>, // bytes
    // // TODO: Add external_draft_tokens_config: ExternalDraftTokensConfig
    // // TODO: Add prompt_tuning_config: PromptTuningConfig
    // // TODO: Add lora_config: LoraConfig
    // // TODO: Add lookahead_config: LookaheadDecodingConfig
    // pub kv_cache_retention_config: KvCacheRetentionConfig,
    // pub logits_post_processor_name: String,
    // pub encoder_input_token_ids: Vec<u32>,
    // pub client_id: Option<u64>, // google.protobuf.UInt64Value
    // pub return_all_generated_tokens: bool,
    // pub priority: f32,
    // pub request_type: u32,
    // // TODO: Add context_phase_params: ContextPhaseParams
    // pub encoder_input_features: Vec<u8>, // bytes
    // pub encoder_output_length: Option<u32>, // google.protobuf.UInt32Value
    // pub cross_attention_mask: Vec<u8>, // bytes
    // pub num_return_sequences: u32,
    // // TODO: Add eagle_config: EagleConfig
    // pub skip_cross_attn_blocks: Vec<u8>, // bytes
}
// todo - return a Result
impl Request {
    /// Builds a streaming request with only the required fields set.
    ///
    /// # Panics
    /// Panics if the builder fails; this cannot happen here because every
    /// required field is provided (hence `expect` with the invariant stated).
    pub fn new(input_token_ids: Vec<u32>, max_tokens: u32) -> Self {
        RequestBuilder::default()
            .input_token_ids(input_token_ids)
            .max_tokens(max_tokens)
            .streaming(true)
            .build()
            .expect("all required Request fields are set")
    }
}
// todo convert to a TryFrom
impl From<crate::protocols::common::llm_backend::BackendInput> for Request {
    /// Maps the common backend request onto a TRT-LLM `Request`.
    fn from(input: crate::protocols::common::llm_backend::BackendInput) -> Self {
        // Return the builder result directly (the original bound it to a
        // temporary and immediately returned it — clippy `let_and_return`).
        RequestBuilder::default()
            .input_token_ids(input.token_ids)
            // Default to 16 generated tokens when the caller sets no limit.
            .max_tokens(input.stop_conditions.max_tokens.unwrap_or(16))
            .streaming(true)
            // TRT-LLM takes a single end id; use the last configured EOS token
            // (None when the list is empty).
            .end_id(input.eos_token_ids.last().cloned())
            .build()
            .expect("all required Request fields are set")
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use crate::kv_router::protocols::ForwardPassMetrics;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use crate::protocols::{
common::{self},
TokenIdType,
};
/// A batch of responses returned by one `await_responses` call, plus a flag
/// indicating the executor is shutting down.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Responses {
    pub responses: Vec<Response>,
    // True when the executor/server is shutting down; no further batches follow.
    pub shutdown: bool,
}
/// One response for a single request. Exactly one of `error_msg` / `output`
/// is expected to be set (mirrors a protobuf oneof).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Response {
    pub request_id: u64,
    pub client_id: Option<u64>, // Optional client ID.
    pub error_msg: Option<String>, // Error message if the request failed.
    pub output: Option<Output>, // Output if the request succeeded.
}
/// One chunk of generated output for a streaming request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Output {
    // True on the last chunk of the stream for this request.
    pub is_final: bool,
    // Newly generated token ids for this chunk.
    pub token_ids: Vec<TokenIdType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cum_log_prob: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<Vec<f64>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReasonEnum>,
}
/// Why generation stopped, serialized as its integer discriminant
/// (`serde_repr`) to match the wire format.
#[derive(Serialize_repr, Deserialize_repr, Debug, Clone)]
#[repr(u8)]
pub enum FinishReasonEnum {
    // Generation is still in progress.
    FinishReasonNotDone = 0,
    // Stopped on the end-of-sequence token.
    FinishReasonEos = 1,
    // Stopped on a stop word/sequence.
    FinishReasonStop = 2,
    // Stopped after reaching the token limit.
    FinishReasonLength = 3,
}
impl From<Output> for common::llm_backend::LLMEngineOutput {
    /// Converts a raw TRT-LLM output chunk into the common engine output type.
    fn from(output: Output) -> Self {
        // `NotDone` and an absent reason both map to "still generating".
        let finish_reason = output.finish_reason.and_then(|reason| match reason {
            FinishReasonEnum::FinishReasonNotDone => None,
            FinishReasonEnum::FinishReasonEos => Some(common::FinishReason::EoS),
            FinishReasonEnum::FinishReasonStop => Some(common::FinishReason::Stop),
            FinishReasonEnum::FinishReasonLength => Some(common::FinishReason::Length),
        });

        common::llm_backend::LLMEngineOutput {
            // todo - propagate mdcsum
            token_ids: output.token_ids,
            tokens: None,
            text: None,
            cum_log_probs: output.cum_log_prob,
            log_probs: None,
            finish_reason,
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::kv::ForwardPassMetrics;
use serde::{Deserialize, Serialize};
/// Per-iteration engine statistics: one `ForwardPassMetrics` entry per
/// forward pass reported by the executor.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct IterStats {
    pub stats: Vec<ForwardPassMetrics>,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment