Unverified Commit 675a9bf5 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove TRT-LLM C++ engine in favor of Python one (#747)

parent d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/request.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
// SamplingConfig Struct
// JSON-facing mirror of tensorrt_llm::executor::SamplingConfig.
// Fields left as std::nullopt are forwarded unset so the executor applies
// its own defaults.
struct SamplingConfig
{
    // 1 disables beam search (greedy / sampling decode).
    uint32_t beam_width = 1;
    std::optional<uint32_t> top_k;
    std::optional<float> top_p;
    std::optional<float> top_p_min;
    std::optional<uint32_t> top_p_reset_ids;
    std::optional<float> top_p_decay;
    std::optional<uint32_t> seed;
    std::optional<float> temperature;
    std::optional<uint32_t> min_tokens;
    std::optional<float> beam_search_diversity_rate;
    std::optional<float> repetition_penalty;
    std::optional<float> presence_penalty;
    std::optional<float> frequency_penalty;
    std::optional<float> length_penalty;
    std::optional<uint32_t> early_stopping;
    std::optional<uint32_t> no_repeat_ngram_size;
    std::optional<uint32_t> num_return_sequences;

    // Build the executor-native config. The argument order must match the
    // ex::SamplingConfig constructor signature exactly.
    ex::SamplingConfig to_executor_config() const
    {
        return ex::SamplingConfig(beam_width,
            top_k,
            top_p,
            top_p_min,
            top_p_reset_ids,
            top_p_decay,
            seed,
            temperature,
            min_tokens,
            beam_search_diversity_rate,
            repetition_penalty,
            presence_penalty,
            frequency_penalty,
            length_penalty,
            early_stopping,
            no_repeat_ngram_size,
            num_return_sequences);
    }
};
// Custom to_json and from_json functions for SamplingConfig
// Serialize a SamplingConfig; optional fields are emitted only when set.
inline void to_json(json& j, const SamplingConfig& s)
{
    j = json{{"beam_width", s.beam_width}};
    // Emit `key` only when the optional holds a value.
    auto put = [&j](const char* key, const auto& opt) {
        if (opt)
        {
            j[key] = *opt;
        }
    };
    put("top_k", s.top_k);
    put("top_p", s.top_p);
    put("top_p_min", s.top_p_min);
    put("top_p_reset_ids", s.top_p_reset_ids);
    put("top_p_decay", s.top_p_decay);
    put("seed", s.seed);
    put("temperature", s.temperature);
    put("min_tokens", s.min_tokens);
    put("beam_search_diversity_rate", s.beam_search_diversity_rate);
    put("repetition_penalty", s.repetition_penalty);
    put("presence_penalty", s.presence_penalty);
    put("frequency_penalty", s.frequency_penalty);
    put("length_penalty", s.length_penalty);
    put("early_stopping", s.early_stopping);
    put("no_repeat_ngram_size", s.no_repeat_ngram_size);
    put("num_return_sequences", s.num_return_sequences);
}
inline void from_json(const json& j, SamplingConfig& s)
{
j.at("beam_width").get_to(s.beam_width);
if (j.contains("top_k"))
s.top_k = j.at("top_k").get<uint32_t>();
if (j.contains("top_p"))
s.top_p = j.at("top_p").get<float>();
if (j.contains("top_p_min"))
s.top_p_min = j.at("top_p_min").get<float>();
if (j.contains("top_p_reset_ids"))
s.top_p_reset_ids = j.at("top_p_reset_ids").get<uint32_t>();
if (j.contains("top_p_decay"))
s.top_p_decay = j.at("top_p_decay").get<float>();
if (j.contains("seed"))
s.seed = j.at("seed").get<uint32_t>();
if (j.contains("temperature"))
s.temperature = j.at("temperature").get<float>();
if (j.contains("min_tokens"))
s.min_tokens = j.at("min_tokens").get<uint32_t>();
if (j.contains("beam_search_diversity_rate"))
s.beam_search_diversity_rate = j.at("beam_search_diversity_rate").get<float>();
if (j.contains("repetition_penalty"))
s.repetition_penalty = j.at("repetition_penalty").get<float>();
if (j.contains("presence_penalty"))
s.presence_penalty = j.at("presence_penalty").get<float>();
if (j.contains("frequency_penalty"))
s.frequency_penalty = j.at("frequency_penalty").get<float>();
if (j.contains("length_penalty"))
s.length_penalty = j.at("length_penalty").get<float>();
if (j.contains("early_stopping"))
s.early_stopping = j.at("early_stopping").get<uint32_t>();
if (j.contains("no_repeat_ngram_size"))
s.no_repeat_ngram_size = j.at("no_repeat_ngram_size").get<uint32_t>();
if (j.contains("num_return_sequences"))
s.num_return_sequences = j.at("num_return_sequences").get<uint32_t>();
}
// OutputConfig Struct
// Mirror of tensorrt_llm::executor::OutputConfig.
// FIX: the bool members were previously uninitialized, so a directly
// value-constructed OutputConfig read indeterminate values. All flags now
// default to false. (NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE still requires
// every key to be present when parsing.)
struct OutputConfig
{
    bool return_log_probs = false;          // Return per-token log probabilities.
    bool return_context_logits = false;     // Return logits for the prompt.
    bool return_generation_logits = false;  // Return logits for generated tokens.
    bool exclude_input_from_output = false; // Omit prompt tokens from the output.
    bool return_encoder_output = false;     // Return encoder output (enc-dec models).
};
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(OutputConfig,
    return_log_probs,
    return_context_logits,
    return_generation_logits,
    exclude_input_from_output,
    return_encoder_output)
// RetentionPriorityAndDuration Struct
// KV-cache retention priority paired with an optional lifetime.
struct RetentionPriorityAndDuration
{
    std::optional<uint32_t> retention_priority; // Retention priority; unset -> executor default.
    std::optional<uint64_t> duration_ms;        // Lifetime in ms; unset -> no expiry specified.
};
// Serialize a RetentionPriorityAndDuration; unset optionals are omitted.
// FIX: start from an explicit empty object — previously, when both fields
// were unset, `j` was never assigned and serialized as `null` instead of `{}`.
inline void to_json(json& j, const RetentionPriorityAndDuration& r)
{
    j = json::object();
    if (r.retention_priority)
        j["retention_priority"] = r.retention_priority.value();
    if (r.duration_ms)
        j["duration_ms"] = r.duration_ms.value();
}
inline void from_json(const json& j, RetentionPriorityAndDuration& r)
{
if (j.contains("retention_priority"))
r.retention_priority = j.at("retention_priority").get<uint32_t>();
if (j.contains("duration_ms"))
r.duration_ms = j.at("duration_ms").get<uint64_t>();
}
// TokenRangeRetentionConfig Struct
// Retention settings for a token range [token_start, token_end).
struct TokenRangeRetentionConfig
{
    uint32_t token_start;              // First token of the range (required).
    std::optional<uint32_t> token_end; // One past the last token; unset -> open-ended.
    uint32_t priority;                 // Retention priority for this range (required).
    std::optional<uint64_t> duration_ms; // Lifetime in ms; unset -> no expiry specified.
};
// Serialize a TokenRangeRetentionConfig; unset optionals are omitted.
inline void to_json(json& j, const TokenRangeRetentionConfig& t)
{
    j = json{{"token_start", t.token_start}, {"priority", t.priority}};
    if (t.token_end.has_value())
    {
        j["token_end"] = *t.token_end;
    }
    if (t.duration_ms.has_value())
    {
        j["duration_ms"] = *t.duration_ms;
    }
}
inline void from_json(const json& j, TokenRangeRetentionConfig& t)
{
j.at("token_start").get_to(t.token_start);
j.at("priority").get_to(t.priority);
if (j.contains("token_end"))
t.token_end = j.at("token_end").get<uint32_t>();
if (j.contains("duration_ms"))
t.duration_ms = j.at("duration_ms").get<uint64_t>();
}
// KvCacheRetentionConfig Struct
// Per-request KV-cache retention policy: per-token-range settings plus the
// priority/lifetime applied during the decode phase.
struct KvCacheRetentionConfig
{
    std::vector<TokenRangeRetentionConfig> token_range_retention_configs;
    uint32_t decode_retention_priority;       // Priority for blocks produced while decoding.
    std::optional<uint64_t> decode_duration_ms; // Decode-block lifetime in ms; unset -> none.
};
// Serialize a KvCacheRetentionConfig; decode_duration_ms is omitted when
// unset. (nlohmann::json orders keys itself, so init order is irrelevant.)
inline void to_json(json& j, const KvCacheRetentionConfig& k)
{
    j = json{{"decode_retention_priority", k.decode_retention_priority},
        {"token_range_retention_configs", k.token_range_retention_configs}};
    if (k.decode_duration_ms.has_value())
    {
        j["decode_duration_ms"] = *k.decode_duration_ms;
    }
}
inline void from_json(const json& j, KvCacheRetentionConfig& k)
{
j.at("token_range_retention_configs").get_to(k.token_range_retention_configs);
j.at("decode_retention_priority").get_to(k.decode_retention_priority);
if (j.contains("decode_duration_ms"))
k.decode_duration_ms = j.at("decode_duration_ms").get<uint64_t>();
}
// Request Struct
// JSON-facing request sent to the executor. Only the fields below are
// currently wired through; the commented-out members track the remaining
// tensorrt_llm::executor::Request surface that is not yet supported.
struct Request
{
    std::vector<int32_t> input_token_ids; // Prompt token ids (required).
    uint32_t max_tokens;                  // Max tokens to generate (required).
    bool streaming;                       // Stream partial results (required in JSON).
    std::optional<SamplingConfig> sampling_config;
    std::optional<OutputConfig> output_config;
    std::optional<uint32_t> end_id;       // End-of-sequence token id.
    // std::optional<uint32_t> pad_id;
    // std::vector<uint32_t> position_ids;
    // std::vector<uint32_t> bad_words;
    // std::vector<uint32_t> stop_words;
    // std::vector<uint8_t> embedding_bias; // bytes
    // // TODO: Add ExternalDraftTokensConfig external_draft_tokens_config;
    // // TODO: Add PromptTuningConfig prompt_tuning_config;
    // // TODO: Add LoraConfig lora_config;
    // // TODO: Add LookaheadDecodingConfig lookahead_config;
    // KvCacheRetentionConfig kv_cache_retention_config;
    // std::string logits_post_processor_name;
    // std::vector<uint32_t> encoder_input_token_ids;
    // std::optional<uint64_t> client_id;
    // bool return_all_generated_tokens;
    // float priority;
    // uint32_t request_type;
    // // TODO: Add ContextPhaseParams context_phase_params;
    // std::vector<uint8_t> encoder_input_features; // bytes
    // std::optional<uint32_t> encoder_output_length;
    // std::vector<uint8_t> cross_attention_mask; // bytes
    // uint32_t num_return_sequences;
    // // TODO: Add EagleConfig eagle_config;
    // std::vector<uint8_t> skip_cross_attn_blocks; // bytes
};
// Custom to_json and from_json functions for Request
// Serialize a Request. The three required members are always written;
// optional members are emitted only when set. Fields of the executor
// request that this binding does not support yet are simply absent.
inline void to_json(json& j, const Request& r)
{
    j = json{
        {"input_token_ids", r.input_token_ids},
        {"max_tokens", r.max_tokens},
        {"streaming", r.streaming},
    };
    if (r.sampling_config.has_value())
    {
        j["sampling_config"] = *r.sampling_config;
    }
    if (r.output_config.has_value())
    {
        j["output_config"] = *r.output_config;
    }
    if (r.end_id.has_value())
    {
        j["end_id"] = *r.end_id;
    }
}
inline void from_json(const json& j, Request& r)
{
j.at("input_token_ids").get_to(r.input_token_ids);
j.at("max_tokens").get_to(r.max_tokens);
j.at("streaming").get_to(r.streaming);
if (j.contains("sampling_config"))
r.sampling_config = j.at("sampling_config").get<SamplingConfig>();
if (j.contains("output_config"))
r.output_config = j.at("output_config").get<OutputConfig>();
// j.at("sampling_config").get_to(r.sampling_config);
// j.at("output_config").get_to(r.output_config);
// j.at("position_ids").get_to(r.position_ids);
// j.at("bad_words").get_to(r.bad_words);
// j.at("stop_words").get_to(r.stop_words);
// j.at("kv_cache_retention_config").get_to(r.kv_cache_retention_config);
// j.at("logits_post_processor_name").get_to(r.logits_post_processor_name);
// j.at("encoder_input_token_ids").get_to(r.encoder_input_token_ids);
// j.at("return_all_generated_tokens").get_to(r.return_all_generated_tokens);
// j.at("priority").get_to(r.priority);
// j.at("request_type").get_to(r.request_type);
// j.at("num_return_sequences").get_to(r.num_return_sequences);
if (j.contains("end_id"))
r.end_id = j.at("end_id").get<uint32_t>();
// if (j.contains("pad_id"))
// r.pad_id = j.at("pad_id").get<uint32_t>();
// if (j.contains("embedding_bias"))
// r.embedding_bias = j.at("embedding_bias").get<std::vector<uint8_t>>();
// if (j.contains("client_id"))
// r.client_id = j.at("client_id").get<uint64_t>();
// if (j.contains("encoder_input_features"))
// r.encoder_input_features = j.at("encoder_input_features").get<std::vector<uint8_t>>();
// if (j.contains("encoder_output_length"))
// r.encoder_output_length = j.at("encoder_output_length").get<uint32_t>();
// if (j.contains("cross_attention_mask"))
// r.cross_attention_mask = j.at("cross_attention_mask").get<std::vector<uint8_t>>();
// if (j.contains("skip_cross_attn_blocks"))
// r.skip_cross_attn_blocks = j.at("skip_cross_attn_blocks").get<std::vector<uint8_t>>();
}
// Parse a JSON-serialized Request and build the executor-native request.
// Throws nlohmann::json exceptions on malformed input or missing required
// keys (input_token_ids, max_tokens, streaming).
tensorrt_llm::executor::Request deserialize_request(const std::string& request_proto)
{
    spdlog::trace("Deserializing request json: {}", request_proto);
    auto j = json::parse(request_proto);
    auto req_in = j.get<Request>();
    spdlog::trace("constructing request with {} input tokens; max tokens: {}",
        req_in.input_token_ids.size(),
        req_in.max_tokens);
    // NOTE(review): `streaming` is hard-coded to true even though the parsed
    // Request carries a `streaming` flag, and the parsed `output_config` is
    // never applied. Confirm this is intentional (engine may be
    // streaming-only) or wire the parsed values through.
    tensorrt_llm::executor::Request request(std::move(req_in.input_token_ids), req_in.max_tokens, true);
    if (req_in.sampling_config)
    {
        spdlog::trace("Setting sampling_config");
        request.setSamplingConfig(req_in.sampling_config->to_executor_config());
    }
    if (req_in.end_id)
    {
        spdlog::trace("Setting end_id: {}", req_in.end_id.value());
        request.setEndId(req_in.end_id.value());
    }
    return request;
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
// FIX: include <string> directly — the declaration below uses std::string,
// so the header must be self-contained.
#include <string>
namespace nvidia::nvllm::trt {
// Parse a JSON-serialized request (see request.cpp) into an executor request.
tensorrt_llm::executor::Request deserialize_request(const std::string& request);
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/response.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
// Forward declarations
struct Response;
struct Output;
// Why a request's output stream ended. Values mirror the wire protocol and
// must stay stable: they are serialized as raw integers in to_json(Output).
enum FinishReasonEnum
{
    FINISH_REASON_NOT_DONE = 0, // Generation still in progress.
    FINISH_REASON_EOS = 1,      // Hit the end-of-sequence token.
    FINISH_REASON_STOP = 2,     // Hit a stop word/sequence.
    FINISH_REASON_LENGTH = 3,   // Hit the max token limit.
};
// Output Struct
// One chunk of generated output for a request.
struct Output
{
    bool is_final;                    // True on the last chunk of the request.
    std::vector<int32_t> token_ids;   // Newly generated token ids.
    std::optional<float> cum_log_prob;            // Cumulative log-prob, if requested.
    std::optional<std::vector<float>> log_probs;  // Per-token log-probs, if requested.
    std::optional<FinishReasonEnum> finish_reason; // Set only once generation is done.
};
// Custom to_json function
// Serialize an Output; unset optionals are omitted from the JSON object.
// finish_reason is written as its raw integer value.
void to_json(json& j, const Output& o)
{
    j = json{{"is_final", o.is_final}, {"token_ids", o.token_ids}};
    if (o.cum_log_prob.has_value())
        j["cum_log_prob"] = o.cum_log_prob.value();
    if (o.log_probs.has_value())
        j["log_probs"] = o.log_probs.value();
    if (o.finish_reason.has_value())
        j["finish_reason"] = static_cast<int>(o.finish_reason.value());
}
void from_json(const json& j, Output& o)
{
j.at("is_final").get_to(o.is_final);
j.at("token_ids").get_to(o.token_ids);
if (j.contains("cum_log_prob") && !j["cum_log_prob"].is_null())
{
o.cum_log_prob = j["cum_log_prob"].get<float>();
}
else
{
o.cum_log_prob = std::nullopt;
}
if (j.contains("log_probs") && !j["log_probs"].is_null())
{
o.log_probs = j["log_probs"].get<std::vector<float>>();
}
else
{
o.log_probs = std::nullopt;
}
if (j.contains("finish_reason") && !j["finish_reason"].is_null())
{
o.finish_reason = static_cast<FinishReasonEnum>(j["finish_reason"].get<int>());
}
else
{
o.finish_reason = std::nullopt;
}
}
// Response Struct
// One response for a request: either an error message or an output chunk
// (convert() guarantees exactly one of the two is set).
struct Response
{
    uint64_t request_id;                 // Executor-assigned request id.
    std::optional<uint64_t> client_id;   // Optional client ID.
    std::optional<std::string> error_msg; // Set when the executor reported an error.
    std::optional<Output> output;         // Set on success.
};
// Serialize a Response; unset optionals are omitted.
inline void to_json(json& j, const Response& p)
{
    j = json{{"request_id", p.request_id}};
    if (p.client_id.has_value())
    {
        j["client_id"] = *p.client_id;
    }
    if (p.error_msg.has_value())
    {
        j["error_msg"] = *p.error_msg;
    }
    if (p.output.has_value())
    {
        j["output"] = *p.output;
    }
}
inline void from_json(const json& j, Response& p)
{
j.at("request_id").get_to(p.request_id);
if (j.contains("client_id"))
p.client_id = j.at("client_id").get<uint64_t>();
if (j.contains("error_msg"))
p.error_msg = j.at("error_msg").get<std::string>();
if (j.contains("output"))
p.output = j.at("output").get<Output>();
}
// Responses Struct
// A batch of responses plus a shutdown flag, serialized as one JSON object
// by serialize_responses().
struct Responses
{
    std::vector<Response> responses;
    bool shutdown; // True when the engine is shutting down and no more batches follow.
};
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Responses, responses, shutdown)
// Convert an executor response into the JSON-facing Response type, moving
// the token/log-prob payloads out of `response` rather than copying them.
// Exactly one of error_msg / output is set in the result.
Response convert(ex::Response&& response)
{
    auto request_id = response.getRequestId();
    auto client_id = response.getClientId();
    if (response.hasError())
    {
        // Error path: no output payload to extract.
        auto error_msg = response.getErrorMsg();
        return Response{request_id, client_id, {error_msg}, std::nullopt};
    }
    auto e_output = response.getResult();
    auto is_final = e_output.isFinal;
    // The per-beam vectors are asserted to hold exactly one entry: this
    // binding only handles single-beam (beam_width == 1) results.
    assert(e_output.outputTokenIds.size() == 1);
    auto token_ids = std::move(e_output.outputTokenIds[0]);
    auto output = Output{is_final, std::move(token_ids), std::nullopt, std::nullopt, std::nullopt};
    if (e_output.cumLogProbs.has_value())
    {
        assert(e_output.cumLogProbs.value().size() == 1);
        output.cum_log_prob = {e_output.cumLogProbs.value()[0]};
    }
    if (e_output.logProbs.has_value())
    {
        assert(e_output.logProbs.value().size() == 1);
        output.log_probs = {std::move(e_output.logProbs.value()[0])};
    }
    if (e_output.finishReasons.size() > 0)
    {
        assert(e_output.finishReasons.size() == 1);
        auto finish_reason = static_cast<FinishReasonEnum>(e_output.finishReasons[0]);
        // NOT_DONE is the "still running" sentinel; only report a reason
        // once generation has actually finished.
        if (finish_reason != FinishReasonEnum::FINISH_REASON_NOT_DONE)
        {
            output.finish_reason = {finish_reason};
        }
    }
    return Response{request_id, client_id, std::nullopt, {output}};
}
// Convert a batch of executor responses into a single serialized JSON
// object ({"responses": [...], "shutdown": bool}).
std::string serialize_responses(std::deque<ex::Response> responses, bool shutdown)
{
    Responses batch{};
    batch.shutdown = shutdown;
    batch.responses.reserve(responses.size());
    // `responses` is owned by value, so each element can be moved out.
    for (auto& response : responses)
    {
        auto converted = convert(std::move(response));
        // convert() always sets exactly one of output / error_msg.
        assert(converted.output.has_value() || converted.error_msg.has_value());
        batch.responses.emplace_back(std::move(converted));
    }
    return json(batch).dump();
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
// FIX: include <deque> and <string> directly — the declaration below uses
// both, so the header must be self-contained.
#include <deque>
#include <string>
namespace nvidia::nvllm::trt {
// Serialize a batch of executor responses to JSON (see response.cpp).
std::string serialize_responses(std::deque<tensorrt_llm::executor::Response> responses, bool shutdown);
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/stats.hpp"
#include <nlohmann/json.hpp>
#include <deque>
using json = nlohmann::json;
namespace nvidia::nvllm::trt {
// Serialize per-iteration engine statistics as a JSON array string.
// The KV-cache block counters are included only when the executor reported
// them for that iteration.
// Refactored: the two branches previously duplicated the four common
// fields; since nlohmann::json orders object keys itself, building the
// common fields once and adding the KV fields conditionally produces
// byte-identical output.
std::string serialize_iter_stats(std::deque<::tensorrt_llm::executor::IterationStats> stats)
{
    json json_stats = json::array();
    for (const auto& stat : stats)
    {
        json entry;
        entry["iter"] = stat.iter;
        entry["request_active_slots"] = stat.numActiveRequests;
        entry["request_total_slots"] = stat.maxNumActiveRequests;
        entry["request_new_active_slots"] = stat.numNewActiveRequests;
        if (stat.kvCacheStats.has_value())
        {
            entry["kv_active_blocks"] = stat.kvCacheStats->usedNumBlocks;
            entry["kv_total_blocks"] = stat.kvCacheStats->maxNumBlocks;
        }
        json_stats.push_back(entry);
    }
    return json_stats.dump();
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
// FIX: include <deque> and <string> directly — the declaration below uses
// both, so the header must be self-contained.
#include <deque>
#include <string>
namespace nvidia::nvllm::trt {
// Serialize per-iteration engine statistics to JSON (see stats.cpp).
std::string serialize_iter_stats(std::deque<tensorrt_llm::executor::IterationStats> stats);
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nvidia/nvllm/nvllm_trt.h"
#include "api/engine.hpp"
#include <cstring>
extern "C" {
// int trtllm_mpi_session_set_communicator(void* world_comm_ptr)
// {
// return nvidia::nvllm::trt::MpiSession::set_communicator(world_comm_ptr);
// }
// Create a streaming engine from a serialized config string.
// Returns nullptr (after printing the error) if construction throws.
nvllm_trt_engine_t nvllm_trt_engine_create(const char* config_proto)
{
    // based on the type of engine, we might choose to create a different concrete engine object
    try
    {
        auto* engine = new nvidia::nvllm::trt::StreamingEngine(std::string(config_proto));
        return reinterpret_cast<nvllm_trt_engine_t>(engine);
    } catch (const std::exception& e)
    {
        printf("Caught exception when initializing tensorrt_llm: %s\n", e.what());
        return nullptr;
    }
}
// Wrap a raw, already-constructed executor pointer in a StreamingEngine.
// "unsafe": the caller is responsible for the pointer's validity.
nvllm_trt_engine_t nvllm_trt_engine_unsafe_create_from_executor(void* engine)
{
    try
    {
        auto* wrapper = new nvidia::nvllm::trt::StreamingEngine(engine);
        return reinterpret_cast<nvllm_trt_engine_t>(wrapper);
    } catch (const std::exception& e)
    {
        printf("Caught exception when initializing from raw pointer: %s\n", e.what());
        return nullptr;
    }
}
// Enqueue a serialized request for `client_id`.
// Returns the engine-assigned request id, or 0 if the enqueue threw
// (the exception details are swallowed at this C boundary).
request_id_t nvllm_trt_engine_enqueue_request(nvllm_trt_engine_t engine, client_id_t client_id, const char* req_proto)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    try
    {
        return eng->enqueue_request(client_id, std::string(req_proto));
    } catch (...)
    {
        return 0;
    }
}
// Block until the engine produces a batch of responses and return it as a
// heap-allocated C string (serialized JSON). The caller must release it
// with nvllm_trt_engine_free_responses().
char* nvllm_trt_engine_await_responses(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    auto responses = eng->await_responses();
    return strdup(responses.c_str());
}
// Block until KV-cache events are available. Returns nullptr when the
// engine reports none; otherwise a heap C string the caller must free via
// nvllm_trt_engine_free_responses().
char* nvllm_trt_engine_await_kv_events(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    auto events = eng->await_kv_events();
    if (!events)
    {
        return nullptr;
    }
    return strdup(events->c_str());
}
// Get basic iteration stats
// Block until iteration statistics are available. Returns nullptr when the
// engine reports none; otherwise a heap C string the caller must free via
// nvllm_trt_engine_free_responses().
char* nvllm_trt_engine_await_iter_stats(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    auto stats = eng->await_iter_stats();
    if (!stats)
    {
        return nullptr;
    }
    return strdup(stats->c_str());
}
// Release a string previously returned by one of the await_* functions
// (they allocate with strdup, so plain free() is the matching release).
void nvllm_trt_engine_free_responses(char* responses)
{
    free(responses);
}
// Ask the engine to cancel the request identified by `request_id`.
void nvllm_trt_engine_cancel_request(nvllm_trt_engine_t engine, uint64_t request_id)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    eng->cancel_request(request_id);
}
// Request engine shutdown (the handle itself is released by
// nvllm_trt_engine_destroy).
void nvllm_trt_engine_shutdown(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    eng->shutdown();
}
// Destroy the engine object behind the opaque handle.
// Always returns NVLLM_TRT_ENGINE_SUCCESS.
int nvllm_trt_engine_destroy(nvllm_trt_engine_t engine)
{
    delete reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    return NVLLM_TRT_ENGINE_SUCCESS;
}
// Non-zero when the engine reports it is ready to accept requests.
int nvllm_trt_engine_is_ready(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    return eng->is_ready();
}
// Non-zero when the engine reports it has completed (shut down).
int nvllm_trt_engine_has_completed(nvllm_trt_engine_t engine)
{
    auto* eng = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    return eng->has_completed();
}
// int trtllm_version_major()
// {
// return TRTLLM_VERSION_MAJOR;
// }
// int trtllm_version_minor()
// {
// return TRTLLM_VERSION_MINOR;
// }
// int trtllm_version_patch()
// {
// return TRTLLM_VERSION_PATCH;
// }
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "dynamo-engine-trtllm"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[dependencies]
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }
anyhow = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
async-openai = "0.27.2"
serde_repr = "0.1"
[build-dependencies]
bindgen = "0.70"
cmake = "0.1"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Build script: locate the nvllm_trt C header (installed system copy, or
/// the in-tree C++ stub which is then built via CMake), emit the matching
/// linker directives, and generate Rust FFI bindings with bindgen.
fn main() {
    extern crate bindgen;
    use cmake::Config;
    use std::env;
    use std::path::PathBuf;

    let installed_headers = "/usr/local/include/nvidia/nvllm/nvllm_trt.h";
    let local_headers = "../../bindings/cpp/nvllm-trt/include/nvidia/nvllm/nvllm_trt.h";

    // Prefer the installed library; fall back to building the stub.
    let headers_path = if PathBuf::from(installed_headers).exists() {
        println!("cargo:warning=nvllm found. Building with installed version...");
        println!("cargo:rustc-link-search=native=/usr/local/lib");
        println!("cargo:rustc-link-search=native=/opt/tensorrt_llm/lib");
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm_nvrtc_wrapper");
        println!("cargo:rustc-link-lib=dylib=nvinfer_plugin_tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=decoder_attention");
        println!("cargo:rerun-if-changed=/usr/local/lib");
        installed_headers
    } else if PathBuf::from(local_headers).exists() {
        println!("cargo:warning=nvllm not found. Building stub version...");
        let dst = Config::new("../../bindings/cpp/nvllm-trt")
            .define("USE_STUBS", "ON")
            .no_build_target(true)
            .build();
        println!("cargo:warning=building stubs in {}", dst.display());
        let dst = dst.canonicalize().unwrap();
        println!("cargo:rustc-link-search=native={}/build", dst.display());
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        // FIX: this path was "../bindings/cpp/nvllm-trt" (one directory
        // level short), so stub-source changes never triggered a rebuild.
        // It now matches the tree passed to Config::new above.
        println!("cargo:rerun-if-changed=../../bindings/cpp/nvllm-trt");
        local_headers
    } else {
        panic!("nvllm_trt.h not found");
    };

    // Generate Rust bindings for the trtllm C API header.
    let bindings = bindgen::Builder::default()
        .header(headers_path)
        .generate()
        .expect("Unable to generate bindings");

    // Write the bindings where the crate's `include!` expects them.
    let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
    bindings
        .write_to_file(out_path.join("bindings.rs"))
        .expect("Could not write bindings!");
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod cpp;
mod engine;
mod processors;
// pub mod protos {
// include!(concat!(env!("OUT_DIR"), "/nvidia.nvllm.trt.proto.rs"));
// }
pub mod protocols;
pub mod config;
use anyhow::Result;
use std::{
collections::HashMap,
ffi::CString,
sync::{atomic::AtomicU64, Arc, Mutex, OnceLock, Weak},
};
use tokio::sync::mpsc;
use processors::{
IterationProcessor, IterationStatsSubscriptionChannel, KvEventProcessor,
KvEventSubscriptionChannel, ProcessorState, ResponseProcessor,
};
/// Rust-side handle to the C++ TensorRT-LLM streaming executor.
pub struct Executor {
    // Shared handle to the C++ engine; also held (weakly) by ExecutionContexts.
    executor: Arc<cpp::Executor>,
    // Monotonic counter used to mint a fresh client id per enqueued request.
    next_id: AtomicU64,
    // client_id -> sender half of that request's output channel.
    response_queues: ResponseQueues,
    // Lazily started background processors; OnceLock guarantees at most one
    // of each is ever created.
    response_processor: OnceLock<ResponseProcessor>,
    kv_event_processor: OnceLock<KvEventProcessor>,
    iteration_processor: OnceLock<IterationProcessor>,
}
/// Map from client id to the sender half of that request's output channel.
type ResponseQueues = Arc<Mutex<HashMap<u64, mpsc::Sender<Result<protocols::Output>>>>>;
impl Executor {
pub fn from_model_path<P: ToString>(model_path: P) -> Result<Self> {
let config = config::ExecutorConfig::new(model_path.to_string());
Self::new(config)
}
pub fn new(config: config::ExecutorConfig) -> Result<Self> {
Ok(Self {
executor: Arc::new(cpp::Executor::new(config)?),
next_id: AtomicU64::new(0),
response_queues: Arc::new(Mutex::new(HashMap::new())),
response_processor: OnceLock::new(),
kv_event_processor: OnceLock::new(),
iteration_processor: OnceLock::new(),
})
}
pub fn has_started(&self) -> bool {
self.executor.has_started()
}
pub fn has_completed(&self) -> bool {
self.executor.has_completed()
}
pub fn enqueue_request(&self, request: protocols::Request) -> Result<ExecutionContext> {
let client_id = self
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (tx, rx) = mpsc::channel(128);
self.response_queues
.lock()
.expect("response_queues lock poisoned")
.insert(client_id, tx);
let json = serde_json::to_string(&request)?;
let str = CString::new(json)?;
let request_id = self
.executor
.enqueue_request(client_id, str)
.inspect_err(|_| {
self.response_queues
.lock()
.expect("response_queues lock poisoned")
.remove(&client_id);
})?;
println!("request_id: {}", request_id);
Ok(ExecutionContext {
request_id,
response_rx: Some(rx),
executor: Arc::downgrade(&self.executor),
})
}
pub fn cancel_request(&self, client_id: u64) {
self.executor.cancel_request(client_id)
}
/// Start a background task to process responses from the TensorRT LLM AsyncEngine
pub fn start_response_processor(&self) {
self.response_processor.get_or_init(|| {
ResponseProcessor::new(self.create_processor(), self.response_queues.clone())
});
}
/// Starts a background task to process kv events
/// TODO - check the TensorRT LLM config and only start this if the server is configured to send kv events
pub fn start_kv_event_processor(&self) {
self.kv_event_processor
.get_or_init(|| KvEventProcessor::new(self.create_processor()));
}
/// Starts a background task to process forward pass / iteration statistics
pub fn start_iteration_metrics_processor(&self) {
self.iteration_processor
.get_or_init(|| IterationProcessor::new(self.create_processor()));
}
/// Subscribes to the KV Events broadcast channel
pub fn subscribe_to_kv_events(&self) -> Result<KvEventSubscriptionChannel> {
self.kv_event_processor
.get_or_init(|| KvEventProcessor::new(self.create_processor()))
.subscribe()
.ok_or(anyhow::anyhow!("Failed to subscribe to KV events"))
}
pub fn subscribe_to_iteration_stats(&self) -> Result<IterationStatsSubscriptionChannel> {
self.iteration_processor
.get_or_init(|| IterationProcessor::new(self.create_processor()))
.subscribe()
.ok_or(anyhow::anyhow!("Failed to subscribe to iteration stats"))
}
/// Issues a shutdown request to the TensorRT LLM AsyncEngine
/// This is a blocking call. After the async engine has shutdown each background processor/thread/task
/// will be joined and the resources will be released.
pub fn shutdown(&mut self) {
self.executor.shutdown();
self.response_processor.take().map(|p| p.join());
self.kv_event_processor.take().map(|p| p.join());
self.iteration_processor.take().map(|p| p.join());
}
/// Constructs a new ProcessorState instance which packages up any bits from the Executor for the processor task
fn create_processor(&self) -> ProcessorState {
ProcessorState::new(self.executor.clone())
}
}
impl Drop for Executor {
    /// Ensures the engine is shut down and all processor threads are joined
    /// even if the caller never invoked [`Executor::shutdown`] explicitly.
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// Handle returned from [`Executor::enqueue_request`]: carries the response
/// stream for one request and allows the request to be cancelled.
pub struct ExecutionContext {
    /// Internal TensorRT LLM request_id; used to cancel the request
    /// This value is present in the response but because we do not know it before hand, it is only used for cancellation
    request_id: u64,

    /// Hold a weak pointer to the executor for cancellation; cancel becomes a
    /// no-op once the executor has been dropped.
    executor: Weak<cpp::Executor>,

    /// Response stream associated with this request (taken at most once).
    response_rx: Option<mpsc::Receiver<Result<protocols::Output>>>,
}
impl ExecutionContext {
    /// Asks the engine to cancel this request.
    ///
    /// Does nothing when the owning [`Executor`] has already been dropped
    /// (the weak handle no longer upgrades).
    pub fn cancel(&self) {
        let Some(executor) = self.executor.upgrade() else {
            return;
        };
        executor.cancel_request(self.request_id);
    }

    /// Moves the response receiver out of this context.
    ///
    /// Returns `None` on every call after the first.
    pub fn take_response_rx(&mut self) -> Option<mpsc::Receiver<Result<protocols::Output>>> {
        self.response_rx.take()
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
/// Configuration serialized to JSON and handed to the C++ engine at creation.
///
/// Optional fields are omitted from the JSON when unset so the engine applies
/// its own defaults.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
pub struct ExecutorConfig {
    /// Path to the TensorRT LLM model.
    model_path: String,
    /// Engine log verbosity; defaults to errors only.
    #[builder(default = "LogLevel::Error")]
    log_level: LogLevel,
    /// Enables chunked-context processing in the engine when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    enable_chunked_context: Option<bool>,
    /// Whether the engine should normalize log probabilities.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    normalize_log_probs: Option<bool>,
    /// Maximum number of iterations retained for iteration statistics.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    iter_stats_max_iterations: Option<u32>,
    /// The number of processes for tensor parallelism. Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    tensor_parallel_size: Option<u32>,
}
/// Engine log verbosity, serialized in lowercase (e.g. "error", "info").
// Derives merged into a single attribute (previously split across two `#[derive]`s).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// Default level: only errors are logged.
    #[default]
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}
impl From<&str> for LogLevel {
fn from(value: &str) -> Self {
match value.to_lowercase().as_str() {
"error" => LogLevel::Error,
"warn" => LogLevel::Warn,
"info" => LogLevel::Info,
"debug" => LogLevel::Debug,
"trace" => LogLevel::Trace,
_ => LogLevel::default(), // Default to Error if no match
}
}
}
impl ExecutorConfig {
    /// Returns a builder for constructing a config field by field.
    pub fn builder() -> ExecutorConfigBuilder {
        ExecutorConfigBuilder::default()
    }

    /// Creates a config for `model_path` with every optional setting left
    /// unset and logging at the default level (errors only).
    pub fn new(model_path: String) -> Self {
        Self {
            model_path,
            ..Default::default()
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Context, Error, Result};
use bindings::nvllm_trt_engine_destroy;
use std::ffi::CStr;
use std::ffi::CString;
use std::ptr::NonNull;
use super::protocols;
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, KvCacheEvents};
mod bindings {
#![allow(warnings, missing_docs)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
use bindings::{
nvllm_trt_engine, nvllm_trt_engine_await_iter_stats, nvllm_trt_engine_await_kv_events,
nvllm_trt_engine_await_responses, nvllm_trt_engine_cancel_request, nvllm_trt_engine_create,
nvllm_trt_engine_enqueue_request, nvllm_trt_engine_free_responses,
nvllm_trt_engine_has_completed, nvllm_trt_engine_is_ready, nvllm_trt_engine_shutdown,
};
use super::config;
/// Thin RAII handle over the C `nvllm_trt_engine` pointer.
///
/// NOTE(review): this type derives `Clone` while its `Drop` destroys the
/// engine; cloning the bare `Executor` would lead to a double
/// shutdown/destroy. All uses in this crate wrap it in an `Arc`, so only the
/// last `Arc` drop runs `Drop` — confirm no caller clones it directly.
#[derive(Debug, Clone)]
pub struct Executor {
    // Owned pointer to the C engine; freed in Drop.
    engine: NonNull<nvllm_trt_engine>,
}

// nvllm_trt_engine is thread safe
// rust does not know that it is thread safe, so we have to tell it
// SAFETY: relies on the C++ engine's own internal synchronization —
// TODO confirm against the nvllm_trt C API contract.
unsafe impl Send for Executor {}
unsafe impl Sync for Executor {}
// The following implementation of Executor provides the convenience methods used to call
// the C/C++ TensorRT API from Rust.
impl Executor {
    /// Creates a new instance of the TensorRT LLM engine and takes ownership of the pointer to
    /// the C/C++ TensorRT LLM engine object.
    ///
    /// Executor implements the Drop trait, so this object is an RAII object and will
    /// free the C/C++ TensorRT LLM engine object when it goes out of scope.
    pub fn new(config: config::ExecutorConfig) -> Result<Self> {
        let json = serde_json::to_string(&config)?;
        let c_config = CString::new(json)?;
        // SAFETY: `c_config` is a valid NUL-terminated string that lives across the call.
        let engine = unsafe { nvllm_trt_engine_create(c_config.as_ptr()) };
        let engine = NonNull::new(engine)
            .ok_or_else(|| Error::msg("Failed to create nvllm_trt_engine".to_string()))?;
        Ok(Self { engine })
    }

    /// Checks if the engine has started asking for new work
    pub fn has_started(&self) -> bool {
        // SAFETY: `self.engine` is a valid, live engine pointer for the lifetime of `self`.
        let result = unsafe { nvllm_trt_engine_is_ready(self.engine.as_ptr()) };
        result != 0
    }

    /// Checks if the engine has completed all work and shutdown
    pub fn has_completed(&self) -> bool {
        // SAFETY: see `has_started`.
        let result = unsafe { nvllm_trt_engine_has_completed(self.engine.as_ptr()) };
        result != 0
    }

    /// Enqueues a request to the engine
    /// The request is sent to the engine as a json encoded string; however, we reserve the right to change
    /// the encoding in the future.
    pub fn enqueue_request(&self, client_id: u64, request: CString) -> Result<u64> {
        tracing::trace!("enqueuing request to trtllm engine");
        // SAFETY: the engine pointer is live and `request` is a valid NUL-terminated string.
        let id = unsafe {
            nvllm_trt_engine_enqueue_request(self.engine.as_ptr(), client_id, request.as_ptr())
        };
        // The C API signals failure with a zero request id.
        if id == 0 {
            return Err(Error::msg("Failed to enqueue request".to_string()));
        }
        Ok(id)
    }

    /// Block on [`nvllm_trt_engine_await_responses`] until a set of responses is received
    /// If the server shutdown, the list of Responses will be empty
    pub fn await_responses(&self) -> Result<protocols::Responses> {
        let responses;
        // SAFETY: the engine pointer is live; the returned buffer is owned by the
        // C side and must be released with `nvllm_trt_engine_free_responses`.
        unsafe {
            let ptr = nvllm_trt_engine_await_responses(self.engine.as_ptr());
            // Defensive: reading a null pointer would be UB; surface it as an error.
            if ptr.is_null() {
                return Err(Error::msg(
                    "nvllm_trt_engine_await_responses returned null".to_string(),
                ));
            }
            let c_str = CStr::from_ptr(ptr);
            let parsed =
                serde_json::from_slice(c_str.to_bytes()).context("Failed to parse responses");
            // Free the C buffer *before* propagating a parse error so it is not leaked.
            nvllm_trt_engine_free_responses(ptr);
            responses = parsed?;
        }
        Ok(responses)
    }

    /// Blocks until the engine publishes a batch of KV cache events.
    ///
    /// Returns an error when this model will never emit KV events (the C API
    /// signals this with a null pointer).
    pub fn await_kv_events(&self) -> Result<KvCacheEvents> {
        let events: KvCacheEvents;
        // SAFETY: the engine pointer is live; the returned buffer is released via
        // `nvllm_trt_engine_free_responses`.
        unsafe {
            let ptr = nvllm_trt_engine_await_kv_events(self.engine.as_ptr());
            if ptr.is_null() {
                return Err(Error::msg(
                    "No KvEvents will be emitted for this model".to_string(),
                ));
            }
            let c_str = CStr::from_ptr(ptr);
            let parsed = serde_json::from_slice(c_str.to_bytes())
                .context(format!("Failed to parse kv cache events: {:?}", c_str));
            // Free the C buffer *before* propagating a parse error so it is not leaked.
            nvllm_trt_engine_free_responses(ptr);
            events = parsed?;
        }
        Ok(events)
    }

    /// Blocks until the engine publishes forward pass / iteration statistics.
    ///
    /// Returns an error when this model will never emit stats (null from the C API).
    #[allow(dead_code)]
    pub fn await_iter_stats(&self) -> Result<protocols::stats::IterStats> {
        let stats: Vec<ForwardPassMetrics>;
        // SAFETY: the engine pointer is live; the returned buffer is released via
        // `nvllm_trt_engine_free_responses`.
        unsafe {
            let ptr = nvllm_trt_engine_await_iter_stats(self.engine.as_ptr());
            if ptr.is_null() {
                return Err(Error::msg(
                    "No iter stats will be emitted for this model".to_string(),
                ));
            }
            let c_str = CStr::from_ptr(ptr);
            let parsed = serde_json::from_slice(c_str.to_bytes())
                .context(format!("Failed to parse iter stats: {:?}", c_str));
            // Free the C buffer *before* propagating a parse error so it is not leaked.
            nvllm_trt_engine_free_responses(ptr);
            stats = parsed?;
        }
        let stats = protocols::stats::IterStats { stats };
        Ok(stats)
    }

    /// Cancels a request by its request_id
    pub fn cancel_request(&self, request_id: u64) {
        // SAFETY: the engine pointer is live for the lifetime of `self`.
        unsafe { nvllm_trt_engine_cancel_request(self.engine.as_ptr(), request_id) };
    }

    /// Shuts down the engine
    pub fn shutdown(&self) {
        // SAFETY: the engine pointer is live for the lifetime of `self`.
        unsafe { nvllm_trt_engine_shutdown(self.engine.as_ptr()) };
    }
}
impl Drop for Executor {
    /// Shuts down and destroys the C engine when the handle is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.engine` is the pointer created in `new`; shutdown is
        // issued before destroy so the engine can drain in-flight work.
        // NOTE(review): shutdown may already have been requested by the owning
        // wrapper — assumes the C API tolerates a repeated shutdown; confirm.
        unsafe {
            nvllm_trt_engine_shutdown(self.engine.as_ptr());
            nvllm_trt_engine_destroy(self.engine.as_ptr());
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Error, Result};
use async_trait::async_trait;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::stream;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use super::Executor;
/// Per-request state threaded through the response stream's `unfold` loop.
struct State {
    // Request id (from the runtime context) used only for log correlation.
    request_id: String,
    // Cancelled by the watcher task when the client stops the request.
    cancel_token: CancellationToken,
    // Receives outputs routed from the engine's response processor.
    response_rx: mpsc::Receiver<Result<super::protocols::Output>>,
    // Dropping this receiver signals the watcher task (via `tx.closed()`) that
    // the stream is gone and it should exit.
    _link_to_cancel_task: tokio::sync::oneshot::Receiver<()>,
    // set to true if we send what we expect to be a final message
    // if the engine's response stream is closed before we send a final message, we can
    // detect that condition and report an unknown error engine stream termination event
    sentinel: bool,
}
// impl Drop for State {
// fn drop(&mut self) {
// tracing::trace!(request_id = self.stream.id(), "dropping state");
// }
// }
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error> for Executor {
    /// Enqueues the request on the TensorRT LLM engine and adapts its response
    /// channel into an annotated output stream, with cancellation wired in both
    /// directions (client stop -> engine cancel; stream drop -> watcher exit).
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        // unpack the request and context
        let (request, context) = request.into_parts();

        // grab the core context
        let context = context.context();
        let context_cloned = context.clone();

        // create a cancellation token and request id
        let cancel_token = CancellationToken::new();
        let request_id = context.id().to_string();

        let mut engine_context = self.enqueue_request(request.into())?;

        // This oneshot pair ties the watcher task's lifetime to the stream's
        // State: dropping the State drops `rx`, which resolves `tx.closed()`.
        let (mut tx, rx) = tokio::sync::oneshot::channel::<()>();

        let state = State {
            request_id,
            cancel_token: cancel_token.clone(),
            _link_to_cancel_task: rx,
            response_rx: engine_context
                .take_response_rx()
                .ok_or(Error::msg("no response rx"))?,
            sentinel: false,
        };

        // create a task to monitor the request's cancellation state
        // todo: spawn on low priority async thread pool
        tokio::spawn(async move {
            tokio::select! {
                _ = context.stopped() => {
                    tracing::debug!(request_id = context.id(), "request cancelled");
                    // Propagate cancellation to the engine and to the stream below.
                    engine_context.cancel();
                    cancel_token.cancel();
                }
                _ = tx.closed() => {
                    // The stream's State was dropped; nothing left to watch.
                    tracing::debug!(request_id = context.id(), "response stream closed");
                }
            }
        });

        // create the response stream
        let stream = stream::unfold(state, |mut state| async move {
            // `sentinel == true` means a terminal message (final/cancel/error)
            // was already emitted, so the stream ends here.
            if state.sentinel {
                tracing::debug!(
                    request_id = state.request_id,
                    "sentinel set, closing stream"
                );
                return None;
            }

            let output = tokio::select! {
                biased;

                // await a response from the trtllm engine's response processor
                output = state.response_rx.recv() => {
                    output
                }

                // if the stream is stopped, we need to:
                // - cancel the request on the trtllm engine (done by the watcher task above)
                // - return an output with a finish reason of cancelled
                // - mark the state as completed by setting the sentinel to true
                _ = state.cancel_token.cancelled() => {
                    tracing::debug!(request_id = state.request_id, "request cancelled");
                    state.sentinel = true;
                    let output = LLMEngineOutput::cancelled();
                    return Some((Annotated::from_data(output), state))
                }
            };

            match output {
                Some(Ok(output)) => {
                    if output.is_final {
                        tracing::debug!(request_id = state.request_id, "final response");
                        state.sentinel = true;
                    }
                    tracing::trace!(request_id = state.request_id, "issue response");
                    let output = LLMEngineOutput::from(output);
                    Some((Annotated::from_data(output), state))
                }
                Some(Err(err)) => {
                    tracing::debug!(request_id = state.request_id, "request failed: {:?}", err);
                    state.sentinel = true;
                    Some((Annotated::from_error(err.to_string()), state))
                }
                None => {
                    tracing::debug!(request_id = state.request_id, "request completed");
                    if !state.sentinel {
                        // The engine closed the channel without a final message;
                        // surface that as an error rather than ending silently.
                        tracing::warn!(
                            request_id = state.request_id,
                            "engine stream terminated before final response or error"
                        );
                        state.sentinel = true;
                        Some((
                            Annotated::<LLMEngineOutput>::from_error(
                                "engine stream terminated before final response".to_string(),
                            ),
                            state,
                        ))
                    } else {
                        None
                    }
                }
            }
        });

        Ok(ResponseStream::new(Box::pin(stream), context_cloned))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{cpp, protocols};
use anyhow::Result;
use std::sync::Arc;
mod iteration;
mod kv;
mod response;
pub use iteration::{IterationProcessor, SubscriptionChannel as IterationStatsSubscriptionChannel};
pub use kv::{KvEventProcessor, KvEventSubscriptionChannel};
pub use response::ResponseProcessor;
/// Bundle of executor resources handed to each background processor thread.
#[derive(Debug)]
pub(crate) struct ProcessorState {
    // Shared handle to the FFI engine the processor blocks on.
    executor: Arc<cpp::Executor>,
}

impl ProcessorState {
    /// Wraps a shared engine handle for use by a background processor.
    pub fn new(executor: Arc<cpp::Executor>) -> Self {
        Self { executor }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
};
use tokio::sync::broadcast;
use super::*;
const CHANNEL_CAPACITY: usize = 256;
type ChannelType = broadcast::Sender<Arc<ForwardPassMetrics>>;
pub type SubscriptionChannel = broadcast::Receiver<Arc<ForwardPassMetrics>>;
/// Background thread that pumps iteration statistics out of the engine and
/// broadcasts them to subscribers.
pub struct IterationProcessor {
    // Worker thread handle, joined via `join()`.
    handle: thread::JoinHandle<()>,
    // Flag checked by the worker loop; set by `join()` to request exit.
    shutdown: Arc<AtomicBool>,
    // Weak reference to the broadcast sender; the worker thread holds the
    // strong reference, so subscription fails once the worker has exited.
    channel: Weak<ChannelType>,
}
impl IterationProcessor {
    /// Creates a new iteration statistics processor and spawns its worker thread.
    pub fn new(state: ProcessorState) -> Self {
        // Shutdown Token
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();

        // Event Channel
        let channel = Arc::new(broadcast::channel(CHANNEL_CAPACITY).0);
        let channel_clone = channel.clone();

        let handle = std::thread::spawn(move || {
            process_events(state, shutdown_clone, channel_clone);
        });

        IterationProcessor {
            handle,
            shutdown,
            // Only the worker thread keeps the channel alive; a weak reference
            // here lets subscription fail cleanly after the worker exits.
            channel: Arc::downgrade(&channel),
        }
    }

    /// Subscribes to the iteration statistics broadcast channel
    /// Multiple subscribers can be created to monitor the iteration stats
    pub fn subscribe(&self) -> Option<SubscriptionChannel> {
        self.channel.upgrade().map(|channel| channel.subscribe())
    }

    /// Signals shutdown, then joins the thread and waits for it to finish
    pub fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}
/// Thread body: blocks on the engine for iteration statistics and fans them
/// out to broadcast subscribers until shutdown is requested or the stats
/// stream ends.
fn process_events(state: ProcessorState, shutdown: Arc<AtomicBool>, channel: Arc<ChannelType>) {
    loop {
        // This blocks the thread until stats are ready or the server shuts down.
        // An error means this model will never emit iteration stats (or the
        // engine is gone) — exit the loop instead of panicking the thread.
        let iters = match state.executor.await_iter_stats() {
            Ok(iters) => iters,
            Err(e) => {
                tracing::debug!("Stopping iteration stats processor: {:?}", e);
                break;
            }
        };
        // Snapshot the flag before publishing so the final batch is still delivered.
        let should_shutdown = shutdown.load(Ordering::Relaxed);
        for iter in iters.stats {
            tracing::debug!("Received iteration stats: {:?}", iter);
            let iter = Arc::new(iter);
            // `send` fails only when there are no subscribers; drop the batch.
            if let Err(e) = channel.send(iter) {
                tracing::debug!("Failed to send message to channel: {:?}", e);
                break;
            }
        }
        if should_shutdown {
            tracing::debug!("Shutting down iteration stats processor");
            break;
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::kv_router::protocols::KvCacheEvents;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
};
use tokio::sync::broadcast;
use super::*;
const KV_EVENT_CHANNEL_CAPACITY: usize = 65536;
type EventChannelType = broadcast::Sender<KvCacheEvents>;
pub type KvEventSubscriptionChannel = broadcast::Receiver<KvCacheEvents>;
/// Background thread that pumps KV cache events out of the engine and
/// broadcasts them to subscribers.
pub struct KvEventProcessor {
    // Worker thread handle, joined via `join()`.
    handle: thread::JoinHandle<()>,
    // Flag checked by the worker loop; set by `join()` to request exit.
    shutdown: Arc<AtomicBool>,
    // Weak reference to the broadcast sender; the worker thread holds the
    // strong reference, so subscription fails once the worker has exited.
    channel: Weak<EventChannelType>,
}
impl KvEventProcessor {
/// Creates a new KV Event Processor
pub fn new(state: ProcessorState) -> Self {
// Shutdown Token
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
// Event Channel
let channel = Arc::new(broadcast::channel(KV_EVENT_CHANNEL_CAPACITY).0);
let channel_clone = channel.clone();
let handle = std::thread::spawn(move || {
process_events(state, shutdown_clone, channel_clone);
});
KvEventProcessor {
handle,
shutdown,
channel: Arc::downgrade(&channel),
}
}
/// Subscribes to the KV Events broadcast channel
/// Multiple subscribers can be created to monitor the KV Events
pub fn subscribe(&self) -> Option<broadcast::Receiver<KvCacheEvents>> {
self.channel.upgrade().map(|channel| channel.subscribe())
}
/// Joins the thread and waits for it to finish
pub fn join(self) -> thread::Result<()> {
self.shutdown.store(true, Ordering::Relaxed);
self.handle.join()
}
}
/// Thread body: blocks on the engine for KV cache events and broadcasts them
/// until the engine signals shutdown or the processor is asked to stop.
fn process_events(
    state: ProcessorState,
    shutdown: Arc<AtomicBool>,
    channel: Arc<EventChannelType>,
) {
    loop {
        // This blocks the thread until events arrive or the server shuts down.
        // An error means this model will never emit KV events (or the engine is
        // gone) — exit the loop instead of panicking the thread.
        let mut message = match state.executor.await_kv_events() {
            Ok(message) => message,
            Err(e) => {
                tracing::debug!("Stopping KV event processor: {:?}", e);
                break;
            }
        };
        // Fold a locally requested shutdown into the message so subscribers
        // observe the shutdown flag too.
        let should_shutdown = message.shutdown || shutdown.load(Ordering::Relaxed);
        message.shutdown = should_shutdown;
        if let Err(e) = channel.send(message) {
            // `send` fails only when there are no subscribers; not fatal.
            tracing::debug!("Failed to send message to channel: {:?}", e);
        }
        if should_shutdown {
            tracing::debug!("Shutting down KV Event Processor");
            break;
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::thread;
use tokio::sync::mpsc;
use super::*;
use crate::executor::ResponseQueues;
/// Background thread that drains engine responses and routes them to the
/// per-request client queues.
pub struct ResponseProcessor {
    handle: thread::JoinHandle<()>,
}

impl ResponseProcessor {
    /// Spawns the response-routing worker thread.
    pub fn new(state: ProcessorState, response_queues: ResponseQueues) -> Self {
        let handle = std::thread::spawn(move || process_responses(state, response_queues));
        Self { handle }
    }

    /// Block and wait for the response processor to finish.
    pub fn join(self) -> thread::Result<()> {
        self.handle.join()
    }
}
#[derive(Debug, thiserror::Error)]
enum ResponseError {
#[error("Response queue dropped; possible client disconnect")]
ResponseQueueDropped,
#[error("Response channel closed; possible client disconnect")]
ChannelClosed,
#[error("Response channel full; backpress detected in response stream")]
ChannelFull,
#[error("Invalid response: no error or result found")]
InvalidResponse,
/// Error indicating that TensorRT LLM returned an error
/// This also indicates that the request was not successful and no further responses
/// will be sent for this request
#[error("TensorRT LLM Engine Error: {0}")]
EngineError(String),
#[error("Completed successfully")]
RequestComplete,
}
/// Thread body: drains batches of responses from the engine and routes each
/// one to the queue registered for its `client_id`, retiring queues for
/// requests that have completed, errored, or whose clients have gone away.
fn process_responses(state: ProcessorState, response_queues: ResponseQueues) {
    loop {
        // this blocks the thread until the response is ready or the server is shutdown
        let message = state
            .executor
            .await_responses()
            .expect("Failed to await responses");

        // check shutdown condition
        if message.shutdown {
            tracing::info!("Server shutdown detected");
            break;
        }

        // process responses - hold the lock while we iterate to avoid any contention
        // grabbing and releasing it for each response
        let mut queues = response_queues.lock().unwrap();
        for output in message.responses {
            let request_id = output.request_id;
            let client_id = output.client_id.expect("client_id is missing");
            // May be None if the queue was already retired for this client.
            let tx = queues.get(&client_id);
            match try_send(tx, output) {
                Ok(_) => {}
                Err(e) => {
                    tracing::trace!(client_id, "processing response: {}", e);
                    match e {
                        ResponseError::InvalidResponse => {
                            // this would likely be a bug on the server; we expect the oneof to be set
                            tracing::warn!(client_id, "Invalid response; No action required");
                        }
                        ResponseError::EngineError(_) => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelFull => {
                            // critical error: the client cannot keep up, so cancel
                            // the request and drop its queue
                            tracing::error!(
                                client_id,
                                "Alert: backpressure detected in response stream"
                            );
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelClosed => {
                            // the first indication the client has disconnected
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ResponseQueueDropped => {
                            // if we get a response for a dropped queue, we need to cancel the request
                            state.executor.cancel_request(request_id);
                        }
                        ResponseError::RequestComplete => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                    }
                }
            }
        }
    }
}
/// Routes a single engine response to its client channel.
///
/// Returns `Ok(())` when the response was delivered and more are expected.
/// Any `Err` tells the caller how to clean up (see [`ResponseError`]); note
/// that `RequestComplete` and `EngineError` are returned even when the
/// message WAS delivered, so the caller can retire the queue.
fn try_send(
    tx: Option<&mpsc::Sender<Result<protocols::Output>>>,
    response: protocols::Response,
) -> Result<(), ResponseError> {
    let mut rc = Ok(());
    let tx = tx.ok_or(ResponseError::ResponseQueueDropped)?;

    // A well-formed response carries exactly one of `output` / `error_msg`.
    let result = match (response.output, response.error_msg) {
        (Some(output), None) => {
            if output.is_final {
                // Deliver the final output below, then tell the caller to retire the queue.
                rc = Err(ResponseError::RequestComplete);
            }
            Ok(output)
        }
        (None, Some(e)) => {
            // Clone only once: the caller's copy carries the message; the
            // client's copy is moved into the channel below.
            rc = Err(ResponseError::EngineError(e.clone()));
            Err(ResponseError::EngineError(e))
        }
        (None, None) => return Err(ResponseError::InvalidResponse),
        (Some(_), Some(_)) => return Err(ResponseError::InvalidResponse),
    };

    // `try_send` (rather than a blocking send) keeps the processor thread from
    // stalling on a slow client; full/closed channels are surfaced for cleanup.
    match tx.try_send(result.map_err(|e| e.into())) {
        Ok(_) => {}
        Err(e) => match e {
            mpsc::error::TrySendError::Closed(_) => {
                return Err(ResponseError::ChannelClosed);
            }
            mpsc::error::TrySendError::Full(_) => {
                return Err(ResponseError::ChannelFull);
            }
        },
    }

    rc
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
pub mod kv;
pub mod outputs;
pub mod stats;
pub use outputs::*;
/// Sampling parameters serialized to JSON for the TensorRT LLM executor.
///
/// Optional fields are omitted from the JSON so the engine applies its own
/// defaults.
// `Debug, Clone` added for consistency with the sibling protocol structs
// (OutputConfig, Request, etc.).
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingConfig {
    /// Beam width; 0 via `Default` — NOTE(review): the C++ side defaults this
    /// to 1; confirm callers always set it explicitly.
    pub beam_width: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_min: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_reset_ids: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_decay: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub beam_search_diversity_rate: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub length_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub early_stopping: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_repeat_ngram_size: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_return_sequences: Option<u32>,
}
/// Controls which optional payloads the engine includes in each response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OutputConfig {
    pub return_log_probs: bool,
    pub return_context_logits: bool,
    pub return_generation_logits: bool,
    /// When true, the prompt tokens are not echoed back in the output.
    pub exclude_input_from_output: bool,
    pub return_encoder_output: bool,
}
/// Optional KV cache retention priority and how long it applies.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RetentionPriorityAndDuration {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retention_priority: Option<u32>, // google.protobuf.UInt32Value
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// KV cache retention settings for a contiguous token range
/// `[token_start, token_end)`; an unset `token_end` is open-ended.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenRangeRetentionConfig {
    pub token_start: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_end: Option<u32>, // google.protobuf.UInt32Value
    pub priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// Per-request KV cache retention policy: per-token-range settings plus the
/// priority applied during the decode phase.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheRetentionConfig {
    pub token_range_retention_configs: Vec<TokenRangeRetentionConfig>,
    pub decode_retention_priority: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decode_duration_ms: Option<u64>, // google.protobuf.UInt64Value
}
/// A generation request, serialized to JSON and handed to the TensorRT LLM
/// engine via `enqueue_request`.
///
/// The commented-out fields mirror the removed protobuf definition and are
/// kept as a map of what the C API could eventually accept.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct Request {
    /// Prompt token ids the generation starts from.
    pub input_token_ids: Vec<u32>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// When true, the engine streams incremental outputs.
    pub streaming: bool,
    // pub sampling_config: SamplingConfig,
    // pub output_config: OutputConfig,
    /// Token id that terminates generation when produced.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_id: Option<u32>,
    // pub pad_id: Option<u32>, // google.protobuf.UInt32Value
    // pub position_ids: Vec<u32>,
    // pub bad_words: Vec<u32>,
    // pub stop_words: Vec<u32>,
    // pub embedding_bias: Vec<u8>, // bytes
    // // TODO: Add external_draft_tokens_config: ExternalDraftTokensConfig
    // // TODO: Add prompt_tuning_config: PromptTuningConfig
    // // TODO: Add lora_config: LoraConfig
    // // TODO: Add lookahead_config: LookaheadDecodingConfig
    // pub kv_cache_retention_config: KvCacheRetentionConfig,
    // pub logits_post_processor_name: String,
    // pub encoder_input_token_ids: Vec<u32>,
    // pub client_id: Option<u64>, // google.protobuf.UInt64Value
    // pub return_all_generated_tokens: bool,
    // pub priority: f32,
    // pub request_type: u32,
    // // TODO: Add context_phase_params: ContextPhaseParams
    // pub encoder_input_features: Vec<u8>, // bytes
    // pub encoder_output_length: Option<u32>, // google.protobuf.UInt32Value
    // pub cross_attention_mask: Vec<u8>, // bytes
    // pub num_return_sequences: u32,
    // // TODO: Add eagle_config: EagleConfig
    // pub skip_cross_attn_blocks: Vec<u8>, // bytes
}
// todo - return a Result
impl Request {
    /// Builds a streaming request from a tokenized prompt and a generation
    /// budget, leaving every other field at its builder default.
    ///
    /// Panics if the builder rejects the configuration (see the todo above).
    pub fn new(input_token_ids: Vec<u32>, max_tokens: u32) -> Self {
        let mut builder = RequestBuilder::default();
        builder
            .input_token_ids(input_token_ids)
            .max_tokens(max_tokens)
            .streaming(true);
        builder.build().unwrap()
    }
}
// todo convert to a TryFrom
impl From<dynamo_llm::protocols::common::llm_backend::BackendInput> for Request {
    /// Maps a backend input into an engine [`Request`].
    ///
    /// Always streams, forwards only the last EOS token id as `end_id`, and
    /// defaults to 16 generated tokens when no explicit limit was supplied.
    /// Panics if the builder rejects the configuration (hence the `TryFrom`
    /// todo above).
    fn from(input: dynamo_llm::protocols::common::llm_backend::BackendInput) -> Self {
        // Return the built request directly instead of binding it to a
        // temporary (clippy: let_and_return).
        RequestBuilder::default()
            .input_token_ids(input.token_ids)
            .max_tokens(input.stop_conditions.max_tokens.unwrap_or(16))
            .streaming(true)
            .end_id(input.eos_token_ids.last().cloned())
            .build()
            .unwrap()
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use dynamo_llm::protocols::{
common::{self},
TokenIdType,
};
/// A batch of responses pulled from the engine in one poll.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Responses {
    /// Responses collected since the previous poll; may be empty.
    pub responses: Vec<Response>,
    /// True when the engine is shutting down and no further batches will follow.
    pub shutdown: bool,
}
/// Engine-side result for a single request: either an error message or an
/// [`Output`], keyed by the engine-assigned request id.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Response {
    /// Engine-assigned id tying this response back to its request.
    pub request_id: u64,
    pub client_id: Option<u64>, // Optional client ID.
    pub error_msg: Option<String>, // Error message if the request failed.
    pub output: Option<Output>, // Output if the request succeeded.
}
/// One chunk of generated output for a request; for a streaming request this
/// is an incremental piece, with `is_final` marking the last chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Output {
    /// True on the final chunk for this request.
    pub is_final: bool,
    /// Newly generated token ids in this chunk.
    pub token_ids: Vec<TokenIdType>,
    /// Cumulative log-probability of the sequence so far; omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cum_log_prob: Option<f64>,
    /// Per-token log-probabilities; omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<Vec<f64>>,
    /// Why generation stopped; omitted while generation is still in progress.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReasonEnum>,
}
/// Why the engine stopped generating, serialized as its `u8` discriminant
/// (via `serde_repr`) — the numeric values are part of the wire format and
/// must not be reordered.
#[derive(Serialize_repr, Deserialize_repr, Debug, Clone)]
#[repr(u8)]
pub enum FinishReasonEnum {
    /// Generation is still in progress.
    FinishReasonNotDone = 0,
    /// An end-of-sequence token was produced.
    FinishReasonEos = 1,
    /// A stop word/sequence was hit.
    FinishReasonStop = 2,
    /// The max-token budget was exhausted.
    FinishReasonLength = 3,
}
impl From<Output> for common::llm_backend::LLMEngineOutput {
    /// Converts a raw engine [`Output`] into the backend's engine-output type.
    ///
    /// `FinishReasonNotDone` collapses to "no finish reason"; detokenized
    /// text and per-token log-probs are not produced here and stay `None`.
    fn from(output: Output) -> Self {
        let finish_reason = output.finish_reason.and_then(|reason| match reason {
            FinishReasonEnum::FinishReasonNotDone => None,
            FinishReasonEnum::FinishReasonEos => Some(common::FinishReason::EoS),
            FinishReasonEnum::FinishReasonStop => Some(common::FinishReason::Stop),
            FinishReasonEnum::FinishReasonLength => Some(common::FinishReason::Length),
        });
        common::llm_backend::LLMEngineOutput {
            // todo - propagate mdcsum
            token_ids: output.token_ids,
            tokens: None,
            text: None,
            cum_log_probs: output.cum_log_prob,
            log_probs: None,
            finish_reason,
        }
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment