Commit 057f8f47 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: TensorRT-LLM engine (#317)

Engine, `tio` support and docs.

Proof of concept / experimental.
parent 11a36651
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// Self-sufficient includes: this header names std::deque and std::string
// directly, so it must not rely on transitive includes from executor.h.
#include <deque>
#include <string>

#include "tensorrt_llm/executor/executor.h"

namespace nvidia::nvllm::trt {

// Serializes a batch of KV-cache events (plus a shutdown flag) into a string
// for transport across the C FFI boundary. Takes the deque by value and
// consumes it.
std::string serialize_kv_events(std::deque<tensorrt_llm::executor::KVCacheEvent> responses, bool shutdown);

}  // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/request.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
// SamplingConfig Struct
//
// JSON-friendly mirror of tensorrt_llm::executor::SamplingConfig. Fields
// absent from the incoming JSON stay std::nullopt and are forwarded as such,
// letting the executor apply its own defaults.
struct SamplingConfig
{
// Beam width; defaults to 1 (no beam search).
uint32_t beam_width = 1;
std::optional<uint32_t> top_k;
std::optional<float> top_p;
std::optional<float> top_p_min;
std::optional<uint32_t> top_p_reset_ids;
std::optional<float> top_p_decay;
std::optional<uint32_t> seed;
std::optional<float> temperature;
std::optional<uint32_t> min_tokens;
std::optional<float> beam_search_diversity_rate;
std::optional<float> repetition_penalty;
std::optional<float> presence_penalty;
std::optional<float> frequency_penalty;
std::optional<float> length_penalty;
std::optional<uint32_t> early_stopping;
std::optional<uint32_t> no_repeat_ngram_size;
std::optional<uint32_t> num_return_sequences;
// Builds the executor's native config. Every value is passed positionally,
// so the argument order below must stay in sync with the parameter order of
// the ex::SamplingConfig constructor.
ex::SamplingConfig to_executor_config() const
{
return ex::SamplingConfig(beam_width,
top_k,
top_p,
top_p_min,
top_p_reset_ids,
top_p_decay,
seed,
temperature,
min_tokens,
beam_search_diversity_rate,
repetition_penalty,
presence_penalty,
frequency_penalty,
length_penalty,
early_stopping,
no_repeat_ngram_size,
num_return_sequences);
}
};
// Custom to_json and from_json functions for SamplingConfig
// Serializes a SamplingConfig; optional fields are emitted only when set.
inline void to_json(json& j, const SamplingConfig& s)
{
    j = json{{"beam_width", s.beam_width}};
    // Writes an optional field into `j` only when it holds a value.
    auto put = [&j](const char* key, const auto& field) {
        if (field)
            j[key] = *field;
    };
    put("top_k", s.top_k);
    put("top_p", s.top_p);
    put("top_p_min", s.top_p_min);
    put("top_p_reset_ids", s.top_p_reset_ids);
    put("top_p_decay", s.top_p_decay);
    put("seed", s.seed);
    put("temperature", s.temperature);
    put("min_tokens", s.min_tokens);
    put("beam_search_diversity_rate", s.beam_search_diversity_rate);
    put("repetition_penalty", s.repetition_penalty);
    put("presence_penalty", s.presence_penalty);
    put("frequency_penalty", s.frequency_penalty);
    put("length_penalty", s.length_penalty);
    put("early_stopping", s.early_stopping);
    put("no_repeat_ngram_size", s.no_repeat_ngram_size);
    put("num_return_sequences", s.num_return_sequences);
}
inline void from_json(const json& j, SamplingConfig& s)
{
j.at("beam_width").get_to(s.beam_width);
if (j.contains("top_k"))
s.top_k = j.at("top_k").get<uint32_t>();
if (j.contains("top_p"))
s.top_p = j.at("top_p").get<float>();
if (j.contains("top_p_min"))
s.top_p_min = j.at("top_p_min").get<float>();
if (j.contains("top_p_reset_ids"))
s.top_p_reset_ids = j.at("top_p_reset_ids").get<uint32_t>();
if (j.contains("top_p_decay"))
s.top_p_decay = j.at("top_p_decay").get<float>();
if (j.contains("seed"))
s.seed = j.at("seed").get<uint32_t>();
if (j.contains("temperature"))
s.temperature = j.at("temperature").get<float>();
if (j.contains("min_tokens"))
s.min_tokens = j.at("min_tokens").get<uint32_t>();
if (j.contains("beam_search_diversity_rate"))
s.beam_search_diversity_rate = j.at("beam_search_diversity_rate").get<float>();
if (j.contains("repetition_penalty"))
s.repetition_penalty = j.at("repetition_penalty").get<float>();
if (j.contains("presence_penalty"))
s.presence_penalty = j.at("presence_penalty").get<float>();
if (j.contains("frequency_penalty"))
s.frequency_penalty = j.at("frequency_penalty").get<float>();
if (j.contains("length_penalty"))
s.length_penalty = j.at("length_penalty").get<float>();
if (j.contains("early_stopping"))
s.early_stopping = j.at("early_stopping").get<uint32_t>();
if (j.contains("no_repeat_ngram_size"))
s.no_repeat_ngram_size = j.at("no_repeat_ngram_size").get<uint32_t>();
if (j.contains("num_return_sequences"))
s.num_return_sequences = j.at("num_return_sequences").get<uint32_t>();
}
// OutputConfig Struct
//
// Mirror of the executor's output toggles. All members are explicitly
// initialized to false so a default-constructed instance is fully defined
// (previously they were left uninitialized, so serializing a
// default-constructed OutputConfig read indeterminate values).
struct OutputConfig
{
    bool return_log_probs = false;
    bool return_context_logits = false;
    bool return_generation_logits = false;
    bool exclude_input_from_output = false;
    bool return_encoder_output = false;
};

// Macro-generated to_json/from_json; note the generated from_json requires
// every listed key to be present in the JSON.
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(OutputConfig,
return_log_probs,
return_context_logits,
return_generation_logits,
exclude_input_from_output,
return_encoder_output)
// RetentionPriorityAndDuration Struct
//
// KV-cache retention override: an optional priority and an optional
// duration in milliseconds.
struct RetentionPriorityAndDuration
{
    std::optional<uint32_t> retention_priority;
    std::optional<uint64_t> duration_ms;
};

// Serializes the struct; only engaged optionals are emitted.
inline void to_json(json& j, const RetentionPriorityAndDuration& r)
{
    // Start from an explicit empty object: the original left `j` untouched
    // when both fields were empty, so an all-empty struct serialized as
    // `null` instead of `{}`.
    j = json::object();
    if (r.retention_priority)
        j["retention_priority"] = r.retention_priority.value();
    if (r.duration_ms)
        j["duration_ms"] = r.duration_ms.value();
}

// Deserializes the struct; missing keys leave the optionals as nullopt.
inline void from_json(const json& j, RetentionPriorityAndDuration& r)
{
    if (j.contains("retention_priority"))
        r.retention_priority = j.at("retention_priority").get<uint32_t>();
    if (j.contains("duration_ms"))
        r.duration_ms = j.at("duration_ms").get<uint64_t>();
}
// TokenRangeRetentionConfig Struct
//
// Retention policy for a contiguous range of tokens starting at token_start.
struct TokenRangeRetentionConfig
{
uint32_t token_start;
std::optional<uint32_t> token_end;
uint32_t priority;
std::optional<uint64_t> duration_ms;
};

// Serializes the required fields plus any engaged optionals.
inline void to_json(json& j, const TokenRangeRetentionConfig& t)
{
    j = json{{"token_start", t.token_start}, {"priority", t.priority}};
    if (t.token_end)
        j["token_end"] = *t.token_end;
    if (t.duration_ms)
        j["duration_ms"] = *t.duration_ms;
}

// Deserializes: token_start and priority are required (throw when missing);
// token_end and duration_ms are optional.
inline void from_json(const json& j, TokenRangeRetentionConfig& t)
{
    j.at("token_start").get_to(t.token_start);
    j.at("priority").get_to(t.priority);
    if (j.contains("token_end"))
        t.token_end.emplace(j.at("token_end").get<uint32_t>());
    if (j.contains("duration_ms"))
        t.duration_ms.emplace(j.at("duration_ms").get<uint64_t>());
}
// KvCacheRetentionConfig Struct
//
// Aggregates per-token-range retention rules plus the decode-phase policy.
struct KvCacheRetentionConfig
{
std::vector<TokenRangeRetentionConfig> token_range_retention_configs;
uint32_t decode_retention_priority;
std::optional<uint64_t> decode_duration_ms;
};

// Serializes the two required fields, plus decode_duration_ms when set.
inline void to_json(json& j, const KvCacheRetentionConfig& k)
{
    j = json{{"token_range_retention_configs", k.token_range_retention_configs},
        {"decode_retention_priority", k.decode_retention_priority}};
    if (k.decode_duration_ms.has_value())
        j["decode_duration_ms"] = *k.decode_duration_ms;
}

// Deserializes: the two required fields throw when missing;
// decode_duration_ms is optional.
inline void from_json(const json& j, KvCacheRetentionConfig& k)
{
    j.at("token_range_retention_configs").get_to(k.token_range_retention_configs);
    j.at("decode_retention_priority").get_to(k.decode_retention_priority);
    if (j.contains("decode_duration_ms"))
        k.decode_duration_ms.emplace(j.at("decode_duration_ms").get<uint64_t>());
}
// Request Struct
//
// JSON-friendly mirror of the subset of tensorrt_llm::executor::Request
// currently supported. The commented-out members are the remaining executor
// fields, kept as a roadmap for future support.
struct Request
{
// Prompt token ids (required).
std::vector<int32_t> input_token_ids;
// Maximum number of tokens to generate (required).
uint32_t max_tokens;
// Parsed from JSON but currently unused: deserialize_request() constructs
// the executor request with streaming hard-coded to true.
bool streaming;
std::optional<SamplingConfig> sampling_config;
// NOTE(review): parsed from JSON but never applied to the executor request
// in deserialize_request() — confirm whether that is intentional.
std::optional<OutputConfig> output_config;
// End-of-sequence token id override.
std::optional<uint32_t> end_id;
// std::optional<uint32_t> pad_id;
// std::vector<uint32_t> position_ids;
// std::vector<uint32_t> bad_words;
// std::vector<uint32_t> stop_words;
// std::vector<uint8_t> embedding_bias; // bytes
// // TODO: Add ExternalDraftTokensConfig external_draft_tokens_config;
// // TODO: Add PromptTuningConfig prompt_tuning_config;
// // TODO: Add LoraConfig lora_config;
// // TODO: Add LookaheadDecodingConfig lookahead_config;
// KvCacheRetentionConfig kv_cache_retention_config;
// std::string logits_post_processor_name;
// std::vector<uint32_t> encoder_input_token_ids;
// std::optional<uint64_t> client_id;
// bool return_all_generated_tokens;
// float priority;
// uint32_t request_type;
// // TODO: Add ContextPhaseParams context_phase_params;
// std::vector<uint8_t> encoder_input_features; // bytes
// std::optional<uint32_t> encoder_output_length;
// std::vector<uint8_t> cross_attention_mask; // bytes
// uint32_t num_return_sequences;
// // TODO: Add EagleConfig eagle_config;
// std::vector<uint8_t> skip_cross_attn_blocks; // bytes
};
// Custom to_json and from_json functions for Request
// Serializes a Request. The three required fields are always emitted;
// optional fields are skipped when unset.
inline void to_json(json& j, const Request& r)
{
    j = json{
        {"input_token_ids", r.input_token_ids},
        {"max_tokens", r.max_tokens},
        {"streaming", r.streaming},
    };
    if (r.sampling_config)
        j["sampling_config"] = *r.sampling_config;
    if (r.output_config)
        j["output_config"] = *r.output_config;
    if (r.end_id)
        j["end_id"] = *r.end_id;
    // Not yet serialized (fields are still commented out on Request):
    // pad_id, position_ids, bad_words, stop_words, embedding_bias,
    // kv_cache_retention_config, logits_post_processor_name,
    // encoder_input_token_ids, client_id, return_all_generated_tokens,
    // priority, request_type, encoder_input_features, encoder_output_length,
    // cross_attention_mask, num_return_sequences, skip_cross_attn_blocks.
}
// Deserializes a Request. input_token_ids, max_tokens and streaming are
// required and throw on absence; the rest are optional.
inline void from_json(const json& j, Request& r)
{
    j.at("input_token_ids").get_to(r.input_token_ids);
    j.at("max_tokens").get_to(r.max_tokens);
    j.at("streaming").get_to(r.streaming);
    if (j.contains("sampling_config"))
        r.sampling_config.emplace(j.at("sampling_config").get<SamplingConfig>());
    if (j.contains("output_config"))
        r.output_config.emplace(j.at("output_config").get<OutputConfig>());
    if (j.contains("end_id"))
        r.end_id.emplace(j.at("end_id").get<uint32_t>());
    // Not yet deserialized (fields are still commented out on Request):
    // pad_id, position_ids, bad_words, stop_words, embedding_bias,
    // kv_cache_retention_config, logits_post_processor_name,
    // encoder_input_token_ids, client_id, return_all_generated_tokens,
    // priority, request_type, encoder_input_features, encoder_output_length,
    // cross_attention_mask, num_return_sequences, skip_cross_attn_blocks.
}
// Parses a JSON-encoded request string (arriving over the FFI) and builds
// the native executor request. Throws nlohmann::json exceptions on
// malformed JSON or missing required fields.
//
// NOTE(review): the parsed `streaming` flag is ignored — the executor
// request is constructed with streaming hard-coded to true below. Confirm
// this is intentional before honoring req_in.streaming.
// NOTE(review): a parsed output_config is never applied to the executor
// request; only sampling_config and end_id are forwarded.
tensorrt_llm::executor::Request deserialize_request(const std::string& request_proto)
{
spdlog::trace("Deserializing request json: {}", request_proto);
auto j = json::parse(request_proto);
auto req_in = j.get<Request>();
spdlog::trace("constructing request with {} input tokens; max tokens: {}",
req_in.input_token_ids.size(),
req_in.max_tokens);
// Third constructor argument is `streaming`, fixed to true.
tensorrt_llm::executor::Request request(std::move(req_in.input_token_ids), req_in.max_tokens, true);
if (req_in.sampling_config)
{
spdlog::trace("Setting sampling_config");
request.setSamplingConfig(req_in.sampling_config->to_executor_config());
}
if (req_in.end_id)
{
spdlog::trace("Setting end_id: {}", req_in.end_id.value());
request.setEndId(req_in.end_id.value());
}
return request;
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// Self-sufficient include: this header names std::string directly, so it
// must not rely on transitive includes from executor.h.
#include <string>

#include "tensorrt_llm/executor/executor.h"

namespace nvidia::nvllm::trt {

// Parses a JSON-encoded request string and builds the native executor
// request. Throws on malformed JSON or missing required fields.
tensorrt_llm::executor::Request deserialize_request(const std::string& request);

}  // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/response.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
// Forward declarations
struct Response;
struct Output;
// Wire encoding of a generation finish reason. Values are static_cast to
// and from tensorrt_llm::executor finish reasons in convert(), so the
// numbering here is assumed to match the executor's enum — confirm if the
// executor enum ever changes.
enum FinishReasonEnum
{
FINISH_REASON_NOT_DONE = 0,
FINISH_REASON_EOS = 1,
FINISH_REASON_STOP = 2,
FINISH_REASON_LENGTH = 3,
};
// Output Struct
//
// Wire-format generation output for a single sequence (convert() asserts
// exactly one sequence per executor result).
struct Output
{
// True when this is the final chunk for the request.
bool is_final;
// Generated token ids for this chunk.
std::vector<int32_t> token_ids;
// Cumulative log probability, when the executor reported one.
std::optional<float> cum_log_prob;
// Per-token log probabilities, when the executor reported them.
std::optional<std::vector<float>> log_probs;
// Set only when generation finished for a reason other than NOT_DONE
// (see convert()).
std::optional<FinishReasonEnum> finish_reason;
};
// Custom to_json function
// Serializes an Output; optional members are emitted only when engaged, and
// finish_reason is written as its integer wire value.
void to_json(json& j, const Output& o)
{
    j = json{{"is_final", o.is_final}, {"token_ids", o.token_ids}};
    if (o.cum_log_prob.has_value())
        j["cum_log_prob"] = o.cum_log_prob.value();
    if (o.log_probs.has_value())
        j["log_probs"] = o.log_probs.value();
    if (o.finish_reason.has_value())
        j["finish_reason"] = static_cast<int>(o.finish_reason.value());
}
// Deserializes an Output. Each optional member is reset to nullopt first and
// then overwritten when its key is present and non-null, so a reused target
// object is always fully overwritten.
void from_json(const json& j, Output& o)
{
    j.at("is_final").get_to(o.is_final);
    j.at("token_ids").get_to(o.token_ids);

    o.cum_log_prob = std::nullopt;
    if (j.contains("cum_log_prob") && !j["cum_log_prob"].is_null())
        o.cum_log_prob = j["cum_log_prob"].get<float>();

    o.log_probs = std::nullopt;
    if (j.contains("log_probs") && !j["log_probs"].is_null())
        o.log_probs = j["log_probs"].get<std::vector<float>>();

    o.finish_reason = std::nullopt;
    if (j.contains("finish_reason") && !j["finish_reason"].is_null())
        o.finish_reason = static_cast<FinishReasonEnum>(j["finish_reason"].get<int>());
}
// Response Struct
//
// Wire-format response for a single request. Exactly one of error_msg or
// output is expected to be set (see the assert in serialize_responses()).
struct Response
{
// Engine-assigned request id.
uint64_t request_id;
std::optional<uint64_t> client_id; // Optional client ID.
// Set when the executor reported an error for this request.
std::optional<std::string> error_msg;
// Generation payload; absent on error responses.
std::optional<Output> output;
};
// Serializes a Response; request_id is always present, the optional members
// only when engaged.
inline void to_json(json& j, const Response& p)
{
    j = json{{"request_id", p.request_id}};
    if (p.client_id.has_value())
        j["client_id"] = *p.client_id;
    if (p.error_msg.has_value())
        j["error_msg"] = *p.error_msg;
    if (p.output.has_value())
        j["output"] = *p.output;
}
// Deserializes a Response; request_id is required (throws when missing),
// the remaining members are optional.
inline void from_json(const json& j, Response& p)
{
    j.at("request_id").get_to(p.request_id);
    if (j.contains("client_id"))
        p.client_id.emplace(j.at("client_id").get<uint64_t>());
    if (j.contains("error_msg"))
        p.error_msg.emplace(j.at("error_msg").get<std::string>());
    if (j.contains("output"))
        p.output.emplace(j.at("output").get<Output>());
}
// Responses Struct
//
// Batch wrapper sent across the FFI: a list of converted responses plus a
// shutdown flag signalling engine termination.
struct Responses
{
std::vector<Response> responses;
// NOTE(review): not default-initialized; serialize_responses() assigns it
// before serialization.
bool shutdown;
};
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Responses, responses, shutdown)
// Converts a native executor response into the wire-format Response,
// consuming the executor response.
//
// Assumes at most one generation sequence per response — the size() == 1
// asserts below encode that. NOTE(review): asserts compile out under
// NDEBUG, so a multi-sequence result (beam width or num_return_sequences
// > 1) would index element 0 unchecked in release builds; confirm upstream
// requests always use a single sequence.
Response convert(ex::Response&& response)
{
auto request_id = response.getRequestId();
auto client_id = response.getClientId();
// Error responses carry only the message; no output payload.
if (response.hasError())
{
auto error_msg = response.getErrorMsg();
return Response{request_id, client_id, {error_msg}, std::nullopt};
}
auto e_output = response.getResult();
auto is_final = e_output.isFinal;
assert(e_output.outputTokenIds.size() == 1);
auto token_ids = std::move(e_output.outputTokenIds[0]);
auto output = Output{is_final, std::move(token_ids), std::nullopt, std::nullopt, std::nullopt};
if (e_output.cumLogProbs.has_value())
{
assert(e_output.cumLogProbs.value().size() == 1);
output.cum_log_prob = {e_output.cumLogProbs.value()[0]};
}
if (e_output.logProbs.has_value())
{
assert(e_output.logProbs.value().size() == 1);
output.log_probs = {std::move(e_output.logProbs.value()[0])};
}
if (e_output.finishReasons.size() > 0)
{
assert(e_output.finishReasons.size() == 1);
auto finish_reason = static_cast<FinishReasonEnum>(e_output.finishReasons[0]);
// NOT_DONE is represented as an absent finish_reason on the wire.
if (finish_reason != FinishReasonEnum::FINISH_REASON_NOT_DONE)
{
output.finish_reason = {finish_reason};
}
}
return Response{request_id, client_id, std::nullopt, {output}};
}
// Converts a batch of executor responses into the JSON wire format consumed
// on the other side of the FFI. Takes the deque by value and consumes it.
std::string serialize_responses(std::deque<ex::Response> responses, bool shutdown)
{
    Responses batch{};
    batch.shutdown = shutdown;
    batch.responses.reserve(responses.size());
    for (auto& response : responses)
    {
        auto converted = convert(std::move(response));
        // Every converted entry must be either an output or an error.
        assert(converted.output.has_value() || converted.error_msg.has_value());
        batch.responses.push_back(std::move(converted));
    }
    return json(batch).dump();
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// Self-sufficient includes: this header names std::deque and std::string
// directly, so it must not rely on transitive includes from executor.h.
#include <deque>
#include <string>

#include "tensorrt_llm/executor/executor.h"

namespace nvidia::nvllm::trt {

// Serializes a batch of executor responses (plus a shutdown flag) into a
// JSON string for transport across the C FFI boundary. Consumes the deque.
std::string serialize_responses(std::deque<tensorrt_llm::executor::Response> responses, bool shutdown);

}  // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/stats.hpp"
#include <nlohmann/json.hpp>
#include <deque>
using json = nlohmann::json;
namespace nvidia::nvllm::trt {
// Serializes per-iteration engine statistics to a JSON array string.
//
// Each entry always carries the iteration index and request-slot counters;
// KV-cache block counters are added only when the executor reported them.
// (nlohmann::json objects serialize with sorted keys, so building the
// common fields once — instead of duplicating both branches as before —
// produces identical output.)
std::string serialize_iter_stats(std::deque<::tensorrt_llm::executor::IterationStats> stats)
{
    json json_stats = json::array();
    for (const auto& stat : stats)
    {
        json entry;
        entry["iter"] = stat.iter;
        entry["request_active_slots"] = stat.numActiveRequests;
        entry["request_total_slots"] = stat.maxNumActiveRequests;
        entry["request_new_active_slots"] = stat.numNewActiveRequests;
        if (stat.kvCacheStats.has_value())
        {
            entry["kv_active_blocks"] = stat.kvCacheStats->usedNumBlocks;
            entry["kv_total_blocks"] = stat.kvCacheStats->maxNumBlocks;
        }
        json_stats.push_back(entry);
    }
    return json_stats.dump();
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// Self-sufficient includes: this header names std::deque and std::string
// directly, so it must not rely on transitive includes from executor.h.
#include <deque>
#include <string>

#include "tensorrt_llm/executor/executor.h"

namespace nvidia::nvllm::trt {

// Serializes per-iteration engine statistics into a JSON array string for
// transport across the C FFI boundary. Consumes the deque.
std::string serialize_iter_stats(std::deque<tensorrt_llm::executor::IterationStats> stats);

}  // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nvidia/nvllm/nvllm_trt.h"
#include "api/engine.hpp"
#include <cstring>
extern "C" {
// int trtllm_mpi_session_set_communicator(void* world_comm_ptr)
// {
// return nvidia::nvllm::trt::MpiSession::set_communicator(world_comm_ptr);
// }
nvllm_trt_engine_t nvllm_trt_engine_create(const char* config_proto)
{
// based on the type of engine, we might choose to create a different concrete engine object
try
{
return reinterpret_cast<nvllm_trt_engine_t>(new nvidia::nvllm::trt::StreamingEngine(std::string(config_proto)));
} catch (const std::exception& e)
{
printf("Caught exception when initializing tensorrt_llm: %s\n", e.what());
return nullptr;
}
}
nvllm_trt_engine_t nvllm_trt_engine_unsafe_create_from_executor(void* engine)
{
try
{
return reinterpret_cast<nvllm_trt_engine_t>(new nvidia::nvllm::trt::StreamingEngine(engine));
} catch (const std::exception& e)
{
printf("Caught exception when initializing from raw pointer: %s\n", e.what());
return nullptr;
}
}
request_id_t nvllm_trt_engine_enqueue_request(nvllm_trt_engine_t engine, client_id_t client_id, const char* req_proto)
{
// Call the enqueue_request method on the C++ class
try
{
return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->enqueue_request(client_id,
std::string(req_proto));
} catch (...)
{
return 0;
}
}
char* nvllm_trt_engine_await_responses(nvllm_trt_engine_t engine)
{
auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_responses();
char* c_responses = strdup(responses.c_str()); // Allocate memory and copy the string
return c_responses; // Return the C string (remember to free this in the calling code)
}
char* nvllm_trt_engine_await_kv_events(nvllm_trt_engine_t engine)
{
auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_kv_events();
if (!responses)
{
return nullptr;
}
char* c_responses = strdup(responses->c_str()); // Allocate memory and copy the string
return c_responses; // Return the C string (remember to free this in the calling code)
}
// Get basic iteration stats
char* nvllm_trt_engine_await_iter_stats(nvllm_trt_engine_t engine)
{
auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_iter_stats();
if (!responses)
{
return nullptr;
}
char* c_responses = strdup(responses->c_str());
return c_responses;
}
void nvllm_trt_engine_free_responses(char* responses)
{
free(responses);
}
void nvllm_trt_engine_cancel_request(nvllm_trt_engine_t engine, uint64_t request_id)
{
reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->cancel_request(request_id);
}
void nvllm_trt_engine_shutdown(nvllm_trt_engine_t engine)
{
reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->shutdown();
}
int nvllm_trt_engine_destroy(nvllm_trt_engine_t engine)
{
auto* trtllm_engine = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
delete trtllm_engine;
return NVLLM_TRT_ENGINE_SUCCESS;
}
int nvllm_trt_engine_is_ready(nvllm_trt_engine_t engine)
{
return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->is_ready();
}
int nvllm_trt_engine_has_completed(nvllm_trt_engine_t engine)
{
return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->has_completed();
}
// int trtllm_version_major()
// {
// return TRTLLM_VERSION_MAJOR;
// }
// int trtllm_version_minor()
// {
// return TRTLLM_VERSION_MINOR;
// }
// int trtllm_version_patch()
// {
// return TRTLLM_VERSION_PATCH;
// }
}
...@@ -404,6 +404,26 @@ version = "1.6.0" ...@@ -404,6 +404,26 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
...@@ -498,6 +518,15 @@ dependencies = [ ...@@ -498,6 +518,15 @@ dependencies = [
"shlex", "shlex",
] ]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-expr" name = "cfg-expr"
version = "0.15.8" version = "0.15.8"
...@@ -541,6 +570,17 @@ dependencies = [ ...@@ -541,6 +570,17 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.28" version = "4.5.28"
...@@ -568,6 +608,15 @@ version = "0.7.4" ...@@ -568,6 +608,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.3"
...@@ -1296,6 +1345,12 @@ version = "0.31.1" ...@@ -1296,6 +1345,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.7" version = "0.4.7"
...@@ -1824,6 +1879,16 @@ version = "0.2.169" ...@@ -1824,6 +1879,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
...@@ -2705,7 +2770,7 @@ dependencies = [ ...@@ -2705,7 +2770,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"quinn-proto", "quinn-proto",
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2", "socket2",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -2723,7 +2788,7 @@ dependencies = [ ...@@ -2723,7 +2788,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand",
"ring", "ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"rustls-pki-types", "rustls-pki-types",
"slab", "slab",
...@@ -2964,6 +3029,12 @@ version = "0.1.24" ...@@ -2964,6 +3029,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
...@@ -3942,10 +4013,12 @@ dependencies = [ ...@@ -3942,10 +4013,12 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "axum 0.8.1",
"bindgen",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -3963,6 +4036,7 @@ dependencies = [ ...@@ -3963,6 +4036,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
...@@ -466,6 +466,26 @@ dependencies = [ ...@@ -466,6 +466,26 @@ dependencies = [
"which", "which",
] ]
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
]
[[package]] [[package]]
name = "bindgen_cuda" name = "bindgen_cuda"
version = "0.1.5" version = "0.1.5"
...@@ -2757,7 +2777,7 @@ version = "0.1.102" ...@@ -2757,7 +2777,7 @@ version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e" checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [ dependencies = [
"bindgen", "bindgen 0.69.5",
"cc", "cc",
"cmake", "cmake",
"find_cuda_helper", "find_cuda_helper",
...@@ -5737,10 +5757,12 @@ dependencies = [ ...@@ -5737,10 +5757,12 @@ dependencies = [
"async-trait", "async-trait",
"async_zmq", "async_zmq",
"axum 0.8.1", "axum 0.8.1",
"bindgen 0.70.1",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -5766,6 +5788,7 @@ dependencies = [ ...@@ -5766,6 +5788,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum 0.27.1", "strum 0.27.1",
"tempfile", "tempfile",
"thiserror 2.0.11", "thiserror 2.0.11",
......
...@@ -35,6 +35,7 @@ llamacpp = ["dep:llama-cpp-2"] ...@@ -35,6 +35,7 @@ llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"] sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
vllm = ["dep:async_zmq"] vllm = ["dep:async_zmq"]
trtllm = []
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"] cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"] metal = ["mistralrs/metal", "llama-cpp-2/metal"]
...@@ -143,6 +144,9 @@ minijinja = { version = "2.3.1", features = ["loader"] } ...@@ -143,6 +144,9 @@ minijinja = { version = "2.3.1", features = ["loader"] }
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] } minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
semver = { version = "1", features = ["serde"] } semver = { version = "1", features = ["serde"] }
# trtllm
serde_repr = "0.1"
[dev-dependencies] [dev-dependencies]
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
...@@ -156,5 +160,9 @@ insta = { version = "1.41", features = [ ...@@ -156,5 +160,9 @@ insta = { version = "1.41", features = [
"filters", "filters",
] } ] }
[build-dependencies]
bindgen = "0.70"
cmake = "0.1"
[profile.dev.package] [profile.dev.package]
insta.opt-level = 3 insta.opt-level = 3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// When the `trtllm` feature is disabled there is nothing to generate or link.
#[cfg(not(feature = "trtllm"))]
fn main() {}
// Build script for the `trtllm` feature: locates the nvllm-trt C API header,
// emits the native link flags, and generates Rust FFI bindings with bindgen.
#[cfg(feature = "trtllm")]
fn main() {
    extern crate bindgen;
    use cmake::Config;
    use std::env;
    use std::path::PathBuf;

    // Preferred: a system-wide install of the nvllm-trt headers/libraries.
    let installed_headers = "/usr/local/include/nvidia/nvllm/nvllm_trt.h";
    // Fallback: the in-repo C++ bindings, compiled here as stubs via cmake.
    let local_headers = "../bindings/cpp/nvllm-trt/include/nvidia/nvllm/nvllm_trt.h";
    let headers_path;
    if PathBuf::from(installed_headers).exists() {
        headers_path = installed_headers;
        println!("cargo:warning=nvllm found. Building with installed version...");
        // Link against the installed nvllm + TensorRT-LLM shared libraries.
        println!("cargo:rustc-link-search=native=/usr/local/lib");
        println!("cargo:rustc-link-search=native=/opt/tensorrt_llm/lib");
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm_nvrtc_wrapper");
        println!("cargo:rustc-link-lib=dylib=nvinfer_plugin_tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=decoder_attention");
        println!("cargo:rerun-if-changed=/usr/local/lib");
    } else if PathBuf::from(local_headers).exists() {
        headers_path = local_headers;
        println!("cargo:warning=nvllm not found. Building stub version...");
        // Build the stub library with cmake; `no_build_target` leaves the
        // artifacts in the cmake build dir rather than installing them.
        let dst = Config::new("../bindings/cpp/nvllm-trt")
            .define("USE_STUBS", "ON")
            .no_build_target(true)
            .build();
        println!("cargo:warning=building stubs in {}", dst.display());
        let dst = dst.canonicalize().unwrap();
        println!("cargo:rustc-link-search=native={}/build", dst.display());
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        println!("cargo:rerun-if-changed=../bindings/cpp/nvllm-trt");
    } else {
        panic!("nvllm_trt.h not found");
    }

    // generate bindings for the trtllm c api
    let bindings = bindgen::Builder::default()
        .header(headers_path)
        .generate()
        .expect("Unable to generate bindings");

    // Write the bindings to OUT_DIR; cpp.rs includes them via `include!`.
    let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
    bindings
        .write_to_file(out_path.join("bindings.rs"))
        .expect("Could not write bindings!");

    // // Build protobuf
    // tonic_build::configure()
    //     .build_server(false)
    //     .compile_protos(&["../../proto/trtllm.proto"], &["../../proto"])
    //     .expect("Failed to compile protos");
}
...@@ -24,3 +24,6 @@ pub mod llamacpp; ...@@ -24,3 +24,6 @@ pub mod llamacpp;
#[cfg(feature = "vllm")] #[cfg(feature = "vllm")]
pub mod vllm; pub mod vllm;
#[cfg(feature = "trtllm")]
pub mod trtllm;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use crate::backend::ExecutionContext;
use triton_distributed_runtime::pipeline::error as pipeline_error;
pub mod executor;
/// Create a TRT-LLM engine wrapped as an [`ExecutionContext`].
///
/// Builds the executor from a minimal config, spawns its background
/// processors, and hands the engine back behind an `Arc`.
pub fn make_engine<P: ToString>(
    // A full repo with .engine files, config.json,
    model_path: P,
    // How many GPUs to use
    tensor_parallel_size: u32,
) -> pipeline_error::Result<ExecutionContext> {
    let engine = executor::Executor::new(
        executor::config::ExecutorConfig::builder()
            .model_path(model_path.to_string())
            .tensor_parallel_size(Some(tensor_parallel_size))
            .build()?,
    )?;

    // Launch the background threads before the engine is handed out.
    engine.start_response_processor();
    engine.start_kv_event_processor();
    engine.start_iteration_metrics_processor();

    let ctx: ExecutionContext = Arc::new(engine);
    Ok(ctx)
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod cpp;
mod engine;
mod processors;
// pub mod protos {
// include!(concat!(env!("OUT_DIR"), "/nvidia.nvllm.trt.proto.rs"));
// }
pub mod protocols;
pub mod config;
use anyhow::Result;
use std::{
collections::HashMap,
ffi::CString,
sync::{atomic::AtomicU64, Arc, Mutex, OnceLock, Weak},
};
use tokio::sync::mpsc;
use processors::{
IterationProcessor, IterationStatsSubscriptionChannel, KvEventProcessor,
KvEventSubscriptionChannel, ProcessorState, ResponseProcessor,
};
/// High-level wrapper around the C++ TensorRT-LLM engine that routes engine
/// outputs to per-request channels and manages the background processors.
pub struct Executor {
    // Shared handle to the FFI engine; `ExecutionContext`s hold it weakly.
    executor: Arc<cpp::Executor>,
    // Monotonic counter used to mint per-request client ids.
    next_id: AtomicU64,
    // client_id -> sender used by the response processor to route outputs.
    response_queues: ResponseQueues,
    // Lazily started background workers (see the `start_*` methods).
    response_processor: OnceLock<ResponseProcessor>,
    kv_event_processor: OnceLock<KvEventProcessor>,
    iteration_processor: OnceLock<IterationProcessor>,
}

/// Map from client id to the per-request response channel.
type ResponseQueues = Arc<Mutex<HashMap<u64, mpsc::Sender<Result<protocols::Output>>>>>;
impl Executor {
pub fn from_model_path<P: ToString>(model_path: P) -> Result<Self> {
let config = config::ExecutorConfig::new(model_path.to_string());
Self::new(config)
}
pub fn new(config: config::ExecutorConfig) -> Result<Self> {
Ok(Self {
executor: Arc::new(cpp::Executor::new(config)?),
next_id: AtomicU64::new(0),
response_queues: Arc::new(Mutex::new(HashMap::new())),
response_processor: OnceLock::new(),
kv_event_processor: OnceLock::new(),
iteration_processor: OnceLock::new(),
})
}
pub fn has_started(&self) -> bool {
self.executor.has_started()
}
pub fn has_completed(&self) -> bool {
self.executor.has_completed()
}
pub fn enqueue_request(&self, request: protocols::Request) -> Result<ExecutionContext> {
let client_id = self
.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (tx, rx) = mpsc::channel(128);
self.response_queues
.lock()
.expect("response_queues lock poisoned")
.insert(client_id, tx);
let json = serde_json::to_string(&request)?;
let str = CString::new(json)?;
let request_id = self
.executor
.enqueue_request(client_id, str)
.inspect_err(|_| {
self.response_queues
.lock()
.expect("response_queues lock poisoned")
.remove(&client_id);
})?;
println!("request_id: {}", request_id);
Ok(ExecutionContext {
request_id,
response_rx: Some(rx),
executor: Arc::downgrade(&self.executor),
})
}
pub fn cancel_request(&self, client_id: u64) {
self.executor.cancel_request(client_id)
}
/// Start a background task to process responses from the TensorRT LLM AsyncEngine
pub fn start_response_processor(&self) {
self.response_processor.get_or_init(|| {
ResponseProcessor::new(self.create_processor(), self.response_queues.clone())
});
}
/// Starts a background task to process kv events
/// TODO - check the TensorRT LLM config and only start this if the server is configured to send kv events
pub fn start_kv_event_processor(&self) {
self.kv_event_processor
.get_or_init(|| KvEventProcessor::new(self.create_processor()));
}
/// Starts a background task to process forward pass / iteration statistics
pub fn start_iteration_metrics_processor(&self) {
self.iteration_processor
.get_or_init(|| IterationProcessor::new(self.create_processor()));
}
/// Subscribes to the KV Events broadcast channel
pub fn subscribe_to_kv_events(&self) -> Result<KvEventSubscriptionChannel> {
self.kv_event_processor
.get_or_init(|| KvEventProcessor::new(self.create_processor()))
.subscribe()
.ok_or(anyhow::anyhow!("Failed to subscribe to KV events"))
}
pub fn subscribe_to_iteration_stats(&self) -> Result<IterationStatsSubscriptionChannel> {
self.iteration_processor
.get_or_init(|| IterationProcessor::new(self.create_processor()))
.subscribe()
.ok_or(anyhow::anyhow!("Failed to subscribe to iteration stats"))
}
/// Issues a shutdown request to the TensorRT LLM AsyncEngine
/// This is a blocking call. After the async engine has shutdown each background processor/thread/task
/// will be joined and the resources will be released.
pub fn shutdown(&mut self) {
self.executor.shutdown();
self.response_processor.take().map(|p| p.join());
self.kv_event_processor.take().map(|p| p.join());
self.iteration_processor.take().map(|p| p.join());
}
/// Constructs a new ProcessorState instance which packages up any bits from the Executor for the processor task
fn create_processor(&self) -> ProcessorState {
ProcessorState::new(self.executor.clone())
}
}
impl Drop for Executor {
    /// Shuts down the engine and joins the background processors when the
    /// executor goes out of scope (see [`Executor::shutdown`]).
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// Per-request handle returned by [`Executor::enqueue_request`], granting
/// access to the response stream and cancellation.
pub struct ExecutionContext {
    /// Internal TensorRT LLM request_id; used to cancel the request.
    /// This value is present in the response but because we do not know it
    /// beforehand, it is only used for cancellation.
    request_id: u64,

    /// Weak pointer to the executor so cancellation never keeps it alive.
    executor: Weak<cpp::Executor>,

    /// Response stream associated with this request; `None` once taken.
    response_rx: Option<mpsc::Receiver<Result<protocols::Output>>>,
}
impl ExecutionContext {
    /// Requests cancellation on the underlying engine; a no-op if the
    /// executor has already been dropped.
    pub fn cancel(&self) {
        let Some(executor) = self.executor.upgrade() else {
            return;
        };
        executor.cancel_request(self.request_id);
    }

    /// Takes ownership of the response receiver, leaving `None` behind so it
    /// can be consumed at most once.
    pub fn take_response_rx(&mut self) -> Option<mpsc::Receiver<Result<protocols::Output>>> {
        std::mem::take(&mut self.response_rx)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
/// Configuration for the TensorRT-LLM executor; serialized to JSON and passed
/// across the FFI boundary when creating the C++ engine.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
pub struct ExecutorConfig {
    /// Path to a full model repo (.engine files, config.json, ...).
    model_path: String,

    /// Engine log verbosity; defaults to `Error`.
    #[builder(default = "LogLevel::Error")]
    log_level: LogLevel,

    // The optional fields below are omitted from the serialized JSON when
    // unset so the engine applies its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    enable_chunked_context: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    normalize_log_probs: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    iter_stats_max_iterations: Option<u32>,

    /// The number of processes for tensor parallelism. Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    tensor_parallel_size: Option<u32>,
}
/// Engine log verbosity, serialized in lowercase ("error", "warn", ...).
///
/// CLEANUP: the two separate `#[derive]` attributes are merged into one, and
/// `Copy`/`PartialEq`/`Eq` are added — purely additive for a fieldless enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// Default level when none is configured.
    #[default]
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}
impl From<&str> for LogLevel {
    /// Case-insensitive parse; anything unrecognized falls back to `Error`.
    fn from(value: &str) -> Self {
        let lowered = value.to_lowercase();
        match lowered.as_str() {
            "warn" => LogLevel::Warn,
            "info" => LogLevel::Info,
            "debug" => LogLevel::Debug,
            "trace" => LogLevel::Trace,
            // "error" and any unknown value both map to the default (Error).
            _ => LogLevel::default(),
        }
    }
}
impl ExecutorConfig {
    /// Returns a builder for assembling a config with optional fields.
    pub fn builder() -> ExecutorConfigBuilder {
        ExecutorConfigBuilder::default()
    }

    /// Minimal config: only the model path is set; the log level takes its
    /// default (`Error`) and every optional knob is left unset.
    pub fn new(model_path: String) -> Self {
        Self {
            model_path,
            ..Self::default()
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Context, Error, Result};
use bindings::nvllm_trt_engine_destroy;
use std::ffi::CStr;
use std::ffi::CString;
use std::ptr::NonNull;
use super::protocols;
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEvents};
mod bindings {
#![allow(warnings, missing_docs)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
use bindings::{
nvllm_trt_engine, nvllm_trt_engine_await_iter_stats, nvllm_trt_engine_await_kv_events,
nvllm_trt_engine_await_responses, nvllm_trt_engine_cancel_request, nvllm_trt_engine_create,
nvllm_trt_engine_enqueue_request, nvllm_trt_engine_free_responses,
nvllm_trt_engine_has_completed, nvllm_trt_engine_is_ready, nvllm_trt_engine_shutdown,
};
use super::config;
/// Owning RAII handle to the C/C++ `nvllm_trt_engine` object.
///
/// BUGFIX: `Clone` was previously derived here. Cloning an owning raw-pointer
/// wrapper whose `Drop` shuts down and destroys the engine would lead to a
/// double shutdown / double free. All sharing in this module goes through
/// `Arc<Executor>`, so the derive is removed.
#[derive(Debug)]
pub struct Executor {
    // Non-null pointer to the C++ engine; created in `new`, released in `drop`.
    engine: NonNull<nvllm_trt_engine>,
}
// SAFETY: the underlying `nvllm_trt_engine` C API is thread safe, but Rust
// cannot see that through the raw pointer, so we assert it here.
unsafe impl Send for Executor {}
unsafe impl Sync for Executor {}
// The following implementation of Executor provides the convenience methods used to call
// the C/C++ TensorRT API from Rust.
impl Executor {
/// Creates a new instance of the TensorRT LLM engine and takes ownership of the pointer to
/// the C/C++ TensorRT LLM engine object.
///
/// Executor implements the Drop trait, so this object is an RAII object and will
/// free the C/C++ TensorRT LLM engine object when it goes out of scope.
pub fn new(config: config::ExecutorConfig) -> Result<Self> {
let json = serde_json::to_string(&config)?;
let c_config = CString::new(json)?;
let engine = unsafe { nvllm_trt_engine_create(c_config.as_ptr()) };
let engine = NonNull::new(engine)
.ok_or_else(|| Error::msg("Failed to create nvllm_trt_engine".to_string()))?;
Ok(Self { engine })
}
/// Checks if the engine has started asking for new work
pub fn has_started(&self) -> bool {
let result = unsafe { nvllm_trt_engine_is_ready(self.engine.as_ptr()) };
if result != 0 {
return true;
}
false
}
/// Checks if the engine has completed all work and shutdown
pub fn has_completed(&self) -> bool {
let result = unsafe { nvllm_trt_engine_has_completed(self.engine.as_ptr()) };
if result != 0 {
return true;
}
false
}
/// Enqueues a request to the engine
/// The request it sent to the engine as a json encoded string; however, we reserve the right to change
/// the encoding in the future.
pub fn enqueue_request(&self, client_id: u64, request: CString) -> Result<u64> {
tracing::trace!("enqueuing request to trtllm engine");
let id = unsafe {
nvllm_trt_engine_enqueue_request(self.engine.as_ptr(), client_id, request.as_ptr())
};
if id == 0 {
return Err(Error::msg("Failed to enqueue request".to_string()));
}
Ok(id)
}
/// Block on [`nvllm_trt_engine_await_responses`] until a set response is received
/// If the server shutdown, the list of Responses will be empty
pub fn await_responses(&self) -> Result<protocols::Responses> {
let responses;
unsafe {
let ptr = nvllm_trt_engine_await_responses(self.engine.as_ptr());
let c_str = CStr::from_ptr(ptr);
let bytes = c_str.to_bytes();
responses = serde_json::from_slice(bytes).context("Failed to parse responses")?;
nvllm_trt_engine_free_responses(ptr);
}
Ok(responses)
}
pub fn await_kv_events(&self) -> Result<KvCacheEvents> {
let events: KvCacheEvents;
unsafe {
let ptr = nvllm_trt_engine_await_kv_events(self.engine.as_ptr());
if ptr.is_null() {
return Err(Error::msg(
"No KvEvents will be emitted for this model".to_string(),
));
}
let c_str = CStr::from_ptr(ptr);
let bytes = c_str.to_bytes();
events = serde_json::from_slice(bytes)
.context(format!("Failed to parse kv cache events: {:?}", c_str))?;
nvllm_trt_engine_free_responses(ptr);
}
Ok(events)
}
#[allow(dead_code)]
pub fn await_iter_stats(&self) -> Result<protocols::stats::IterStats> {
let stats: Vec<ForwardPassMetrics>;
unsafe {
let ptr = nvllm_trt_engine_await_iter_stats(self.engine.as_ptr());
if ptr.is_null() {
return Err(Error::msg(
"No iter stats will be emitted for this model".to_string(),
));
}
let c_str = CStr::from_ptr(ptr);
let bytes = c_str.to_bytes();
stats = serde_json::from_slice(bytes)
.context(format!("Failed to parse iter stats: {:?}", c_str))?;
nvllm_trt_engine_free_responses(ptr);
}
let stats = protocols::stats::IterStats { stats };
Ok(stats)
}
/// Cancels a request by its request_id
pub fn cancel_request(&self, request_id: u64) {
unsafe { nvllm_trt_engine_cancel_request(self.engine.as_ptr(), request_id) };
}
/// Shuts down the engine
pub fn shutdown(&self) {
unsafe { nvllm_trt_engine_shutdown(self.engine.as_ptr()) };
}
}
impl Drop for Executor {
    /// Shuts down and then destroys the C++ engine object.
    fn drop(&mut self) {
        // SAFETY: `self.engine` is the valid, owned pointer created in `new`.
        // Shutdown is issued before destroy. NOTE(review): if this handle were
        // ever cloned, destroy would run twice — confirm sharing is only via
        // `Arc<Executor>`.
        unsafe {
            nvllm_trt_engine_shutdown(self.engine.as_ptr());
            nvllm_trt_engine_destroy(self.engine.as_ptr());
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Error, Result};
use async_trait::async_trait;
use futures::stream;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use super::Executor;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
/// Per-request state threaded through the `stream::unfold` loop in
/// [`generate`]; owns the receiving end of the engine's response channel.
struct State {
    // Request id used only for log correlation.
    request_id: String,
    // Cancelled by the monitor task when the caller's context stops.
    cancel_token: CancellationToken,
    // Outputs produced by the engine's response processor for this request.
    response_rx: mpsc::Receiver<Result<super::protocols::Output>>,
    // Held (never read) so that dropping this State closes the oneshot and
    // wakes the monitor task, letting it exit.
    _link_to_cancel_task: tokio::sync::oneshot::Receiver<()>,
    // set to true if we send what we expect to be a final message
    // if the engine's response stream is closed before we send a final message, we can
    // detect that condition and report an unknown error engine stream termination event
    sentinel: bool,
}
// impl Drop for State {
// fn drop(&mut self) {
// tracing::trace!(request_id = self.stream.id(), "dropping state");
// }
// }
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error> for Executor {
    /// Enqueues the request on the TRT-LLM engine and adapts its response
    /// channel into an annotated output stream, handling cancellation and
    /// premature engine-stream termination.
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        // unpack the request and context
        let (request, context) = request.into_parts();

        // grab the core context
        let context = context.context();
        let context_cloned = context.clone();

        // create a cancellation token and request id
        let cancel_token = CancellationToken::new();
        let request_id = context.id().to_string();

        let mut engine_context = self.enqueue_request(request.into())?;

        // This oneshot links the monitor task below to the stream's State:
        // dropping the State (which holds `rx`) makes `tx.closed()` fire so
        // the monitor task can exit.
        let (mut tx, rx) = tokio::sync::oneshot::channel::<()>();

        let state = State {
            request_id,
            cancel_token: cancel_token.clone(),
            _link_to_cancel_task: rx,
            response_rx: engine_context
                .take_response_rx()
                .ok_or(Error::msg("no response rx"))?,
            sentinel: false,
        };

        // create a task to monitor the request's cancellation state
        // todo: spawn on low priority async thread pool
        tokio::spawn(async move {
            tokio::select! {
                _ = context.stopped() => {
                    // Caller stopped the request: cancel it on the engine and
                    // signal the response stream via the token.
                    tracing::debug!(request_id = context.id(), "request cancelled");
                    engine_context.cancel();
                    cancel_token.cancel();
                }
                _ = tx.closed() => {
                    // The stream State was dropped; nothing left to monitor.
                    tracing::debug!(request_id = context.id(), "response stream closed");
                }
            }
        });

        // create the response stream
        let stream = stream::unfold(state, |mut state| async move {
            // A previous iteration emitted the terminal message; close now.
            if state.sentinel {
                tracing::debug!(
                    request_id = state.request_id,
                    "sentinel set, closing stream"
                );
                return None;
            }

            let output = tokio::select! {
                biased;
                // await a response from the trtllm engine's response processor
                output = state.response_rx.recv() => {
                    output
                }
                // if the stream is stopped, we need to:
                // - cancel the request on the trtllm engine (the monitor task does this)
                // - return an output with a finish reason of cancelled
                // - mark the state as completed by setting the sentinel to true
                _ = state.cancel_token.cancelled() => {
                    tracing::debug!(request_id = state.request_id, "request cancelled");
                    state.sentinel = true;
                    let output = LLMEngineOutput::cancelled();
                    return Some((Annotated::from_data(output), state))
                }
            };

            match output {
                Some(Ok(output)) => {
                    // A final response flips the sentinel so the next poll
                    // terminates the stream cleanly.
                    if output.is_final {
                        tracing::debug!(request_id = state.request_id, "final response");
                        state.sentinel = true;
                    }
                    tracing::trace!(request_id = state.request_id, "issue response");
                    let output = LLMEngineOutput::from(output);
                    Some((Annotated::from_data(output), state))
                }
                Some(Err(err)) => {
                    // Engine-side failure: surface it as an error annotation
                    // and terminate on the next poll.
                    tracing::debug!(request_id = state.request_id, "request failed: {:?}", err);
                    state.sentinel = true;
                    Some((Annotated::from_error(err.to_string()), state))
                }
                None => {
                    // Response channel closed. If no final message was sent,
                    // the engine stream ended unexpectedly — report an error.
                    tracing::debug!(request_id = state.request_id, "request completed");
                    if !state.sentinel {
                        tracing::warn!(
                            request_id = state.request_id,
                            "engine stream terminated before final response or error"
                        );
                        state.sentinel = true;
                        Some((
                            Annotated::<LLMEngineOutput>::from_error(
                                "engine stream terminated before final response".to_string(),
                            ),
                            state,
                        ))
                    } else {
                        None
                    }
                }
            }
        });

        Ok(ResponseStream::new(Box::pin(stream), context_cloned))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{cpp, protocols};
use anyhow::Result;
use std::sync::Arc;
mod iteration;
mod kv;
mod response;
pub use iteration::{IterationProcessor, SubscriptionChannel as IterationStatsSubscriptionChannel};
pub use kv::{KvEventProcessor, KvEventSubscriptionChannel};
pub use response::ResponseProcessor;
/// Bits from the Executor packaged up for a background processor thread.
#[derive(Debug)]
pub(crate) struct ProcessorState {
    // Strong handle keeps the FFI engine alive for the processor's lifetime.
    executor: Arc<cpp::Executor>,
}

impl ProcessorState {
    /// Wraps the shared engine handle for hand-off to a processor thread.
    pub fn new(executor: Arc<cpp::Executor>) -> Self {
        Self { executor }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::kv_router::protocols::ForwardPassMetrics;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
thread,
};
use tokio::sync::broadcast;
use super::*;
// Capacity of the iteration-stats broadcast channel.
const CHANNEL_CAPACITY: usize = 256;

// Sender/receiver halves of the iteration-stats broadcast channel.
type ChannelType = broadcast::Sender<Arc<ForwardPassMetrics>>;
pub type SubscriptionChannel = broadcast::Receiver<Arc<ForwardPassMetrics>>;

/// Background thread that pulls per-iteration statistics from the engine and
/// broadcasts them to subscribers.
pub struct IterationProcessor {
    // Join handle for the processor thread.
    handle: thread::JoinHandle<()>,
    // Cooperative shutdown flag polled by the thread's loop.
    shutdown: Arc<AtomicBool>,
    // Weak handle to the sender; the processor thread owns the strong one.
    channel: Weak<ChannelType>,
}
impl IterationProcessor {
    /// Creates a new iteration-stats processor.
    ///
    /// (Doc previously said "KV Event Processor" — copy-paste.) Spawns a
    /// dedicated OS thread that blocks on the engine for iteration statistics
    /// and fans them out over a broadcast channel.
    pub fn new(state: ProcessorState) -> Self {
        // Shutdown Token
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();

        // Event Channel — the thread holds the only strong sender once `new`
        // returns; this struct keeps just a weak handle.
        let channel = Arc::new(broadcast::channel(CHANNEL_CAPACITY).0);
        let channel_clone = channel.clone();

        let handle = std::thread::spawn(move || {
            process_events(state, shutdown_clone, channel_clone);
        });

        IterationProcessor {
            handle,
            shutdown,
            channel: Arc::downgrade(&channel),
        }
    }

    /// Subscribes to the iteration-stats broadcast channel.
    /// Multiple subscribers can be created to monitor the stats.
    /// Returns `None` once the processor thread has exited.
    pub fn subscribe(&self) -> Option<SubscriptionChannel> {
        self.channel.upgrade().map(|channel| channel.subscribe())
    }

    /// Joins the thread and waits for it to finish.
    /// NOTE(review): the loop only observes the flag after its blocking
    /// engine call returns, so this may wait for one more stats batch.
    pub fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}
/// Thread body: blocks on the engine for iteration statistics and fans them
/// out to broadcast subscribers until shutdown is requested or the engine
/// stops producing stats.
fn process_events(state: ProcessorState, shutdown: Arc<AtomicBool>, channel: Arc<ChannelType>) {
    loop {
        // This blocks the thread until stats are ready or the server shuts down.
        // BUGFIX: the engine reports "no iter stats for this model" as an
        // `Err`; the original `expect()` turned that into a thread panic.
        // Exit the processor loop gracefully instead.
        let iters = match state.executor.await_iter_stats() {
            Ok(iters) => iters,
            Err(e) => {
                tracing::debug!("Stopping iteration stats processor: {:?}", e);
                break;
            }
        };

        // Sample the flag before fan-out so a shutdown requested while we
        // were blocked still delivers this final batch.
        let should_shutdown = shutdown.load(Ordering::Relaxed);

        for iter in iters.stats {
            tracing::debug!("Received iteration stats: {:?}", iter);
            let iter = Arc::new(iter);
            if let Err(e) = channel.send(iter) {
                // `broadcast::Sender::send` only fails when there are no
                // active receivers; drop the rest of this batch.
                tracing::debug!("Failed to send message to channel: {:?}", e);
                break;
            }
        }

        if should_shutdown {
            // BUGFIX: message previously said "KV Event Processor" (copy-paste).
            tracing::debug!("Shutting down iteration stats processor");
            break;
        }
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment