OpenDAS / dynamo

Commit 675a9bf5 (Unverified)
Authored Apr 18, 2025 by Graham King; committed by GitHub on Apr 18, 2025
Parent: d797b4ba

chore: Remove TRT-LLM C++ engine in favor of Python one (#747)
Changes: 42 files (showing 20 changed files with 0 additions and 2270 deletions)
lib/bindings/cpp/nvllm-trt/src/engine_trt/request.cpp      +0 -390
lib/bindings/cpp/nvllm-trt/src/engine_trt/request.hpp      +0 -24
lib/bindings/cpp/nvllm-trt/src/engine_trt/response.cpp     +0 -206
lib/bindings/cpp/nvllm-trt/src/engine_trt/response.hpp     +0 -24
lib/bindings/cpp/nvllm-trt/src/engine_trt/stats.cpp        +0 -56
lib/bindings/cpp/nvllm-trt/src/engine_trt/stats.hpp        +0 -24
lib/bindings/cpp/nvllm-trt/src/nvllm_trt.cpp               +0 -143
lib/engines/trtllm/Cargo.toml                              +0 -48
lib/engines/trtllm/build.rs                                +0 -76
lib/engines/trtllm/src/executor.rs                         +0 -193
lib/engines/trtllm/src/executor/config.rs                  +0 -84
lib/engines/trtllm/src/executor/cpp.rs                     +0 -168
lib/engines/trtllm/src/executor/engine.rs                  +0 -165
lib/engines/trtllm/src/executor/processors.rs              +0 -37
lib/engines/trtllm/src/executor/processors/iteration.rs    +0 -98
lib/engines/trtllm/src/executor/processors/kv.rs           +0 -98
lib/engines/trtllm/src/executor/processors/response.rs     +0 -165
lib/engines/trtllm/src/executor/protocols.rs               +0 -173
lib/engines/trtllm/src/executor/protocols/kv.rs            +0 -16
lib/engines/trtllm/src/executor/protocols/outputs.rs       +0 -82

lib/bindings/cpp/nvllm-trt/src/engine_trt/request.cpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/request.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;

namespace ex = tensorrt_llm::executor;

namespace nvidia::nvllm::trt {

// SamplingConfig Struct
struct SamplingConfig {
    uint32_t beam_width = 1;
    std::optional<uint32_t> top_k;
    std::optional<float> top_p;
    std::optional<float> top_p_min;
    std::optional<uint32_t> top_p_reset_ids;
    std::optional<float> top_p_decay;
    std::optional<uint32_t> seed;
    std::optional<float> temperature;
    std::optional<uint32_t> min_tokens;
    std::optional<float> beam_search_diversity_rate;
    std::optional<float> repetition_penalty;
    std::optional<float> presence_penalty;
    std::optional<float> frequency_penalty;
    std::optional<float> length_penalty;
    std::optional<uint32_t> early_stopping;
    std::optional<uint32_t> no_repeat_ngram_size;
    std::optional<uint32_t> num_return_sequences;

    ex::SamplingConfig to_executor_config() const
    {
        return ex::SamplingConfig(
            beam_width,
            top_k,
            top_p,
            top_p_min,
            top_p_reset_ids,
            top_p_decay,
            seed,
            temperature,
            min_tokens,
            beam_search_diversity_rate,
            repetition_penalty,
            presence_penalty,
            frequency_penalty,
            length_penalty,
            early_stopping,
            no_repeat_ngram_size,
            num_return_sequences);
    }
};

// Custom to_json and from_json functions for SamplingConfig
inline void to_json(json& j, const SamplingConfig& s)
{
    j = json{{"beam_width", s.beam_width}};
    if (s.top_k) j["top_k"] = s.top_k.value();
    if (s.top_p) j["top_p"] = s.top_p.value();
    if (s.top_p_min) j["top_p_min"] = s.top_p_min.value();
    if (s.top_p_reset_ids) j["top_p_reset_ids"] = s.top_p_reset_ids.value();
    if (s.top_p_decay) j["top_p_decay"] = s.top_p_decay.value();
    if (s.seed) j["seed"] = s.seed.value();
    if (s.temperature) j["temperature"] = s.temperature.value();
    if (s.min_tokens) j["min_tokens"] = s.min_tokens.value();
    if (s.beam_search_diversity_rate) j["beam_search_diversity_rate"] = s.beam_search_diversity_rate.value();
    if (s.repetition_penalty) j["repetition_penalty"] = s.repetition_penalty.value();
    if (s.presence_penalty) j["presence_penalty"] = s.presence_penalty.value();
    if (s.frequency_penalty) j["frequency_penalty"] = s.frequency_penalty.value();
    if (s.length_penalty) j["length_penalty"] = s.length_penalty.value();
    if (s.early_stopping) j["early_stopping"] = s.early_stopping.value();
    if (s.no_repeat_ngram_size) j["no_repeat_ngram_size"] = s.no_repeat_ngram_size.value();
    if (s.num_return_sequences) j["num_return_sequences"] = s.num_return_sequences.value();
}

inline void from_json(const json& j, SamplingConfig& s)
{
    j.at("beam_width").get_to(s.beam_width);
    if (j.contains("top_k")) s.top_k = j.at("top_k").get<uint32_t>();
    if (j.contains("top_p")) s.top_p = j.at("top_p").get<float>();
    if (j.contains("top_p_min")) s.top_p_min = j.at("top_p_min").get<float>();
    if (j.contains("top_p_reset_ids")) s.top_p_reset_ids = j.at("top_p_reset_ids").get<uint32_t>();
    if (j.contains("top_p_decay")) s.top_p_decay = j.at("top_p_decay").get<float>();
    if (j.contains("seed")) s.seed = j.at("seed").get<uint32_t>();
    if (j.contains("temperature")) s.temperature = j.at("temperature").get<float>();
    if (j.contains("min_tokens")) s.min_tokens = j.at("min_tokens").get<uint32_t>();
    if (j.contains("beam_search_diversity_rate")) s.beam_search_diversity_rate = j.at("beam_search_diversity_rate").get<float>();
    if (j.contains("repetition_penalty")) s.repetition_penalty = j.at("repetition_penalty").get<float>();
    if (j.contains("presence_penalty")) s.presence_penalty = j.at("presence_penalty").get<float>();
    if (j.contains("frequency_penalty")) s.frequency_penalty = j.at("frequency_penalty").get<float>();
    if (j.contains("length_penalty")) s.length_penalty = j.at("length_penalty").get<float>();
    if (j.contains("early_stopping")) s.early_stopping = j.at("early_stopping").get<uint32_t>();
    if (j.contains("no_repeat_ngram_size")) s.no_repeat_ngram_size = j.at("no_repeat_ngram_size").get<uint32_t>();
    if (j.contains("num_return_sequences")) s.num_return_sequences = j.at("num_return_sequences").get<uint32_t>();
}

// OutputConfig Struct
struct OutputConfig {
    bool return_log_probs;
    bool return_context_logits;
    bool return_generation_logits;
    bool exclude_input_from_output;
    bool return_encoder_output;
};

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
    OutputConfig, return_log_probs, return_context_logits, return_generation_logits, exclude_input_from_output,
    return_encoder_output)

// RetentionPriorityAndDuration Struct
struct RetentionPriorityAndDuration {
    std::optional<uint32_t> retention_priority;
    std::optional<uint64_t> duration_ms;
};

inline void to_json(json& j, const RetentionPriorityAndDuration& r)
{
    if (r.retention_priority) j["retention_priority"] = r.retention_priority.value();
    if (r.duration_ms) j["duration_ms"] = r.duration_ms.value();
}

inline void from_json(const json& j, RetentionPriorityAndDuration& r)
{
    if (j.contains("retention_priority")) r.retention_priority = j.at("retention_priority").get<uint32_t>();
    if (j.contains("duration_ms")) r.duration_ms = j.at("duration_ms").get<uint64_t>();
}

// TokenRangeRetentionConfig Struct
struct TokenRangeRetentionConfig {
    uint32_t token_start;
    std::optional<uint32_t> token_end;
    uint32_t priority;
    std::optional<uint64_t> duration_ms;
};

inline void to_json(json& j, const TokenRangeRetentionConfig& t)
{
    j = json{{"token_start", t.token_start}, {"priority", t.priority}};
    if (t.token_end) j["token_end"] = t.token_end.value();
    if (t.duration_ms) j["duration_ms"] = t.duration_ms.value();
}

inline void from_json(const json& j, TokenRangeRetentionConfig& t)
{
    j.at("token_start").get_to(t.token_start);
    j.at("priority").get_to(t.priority);
    if (j.contains("token_end")) t.token_end = j.at("token_end").get<uint32_t>();
    if (j.contains("duration_ms")) t.duration_ms = j.at("duration_ms").get<uint64_t>();
}

// KvCacheRetentionConfig Struct
struct KvCacheRetentionConfig {
    std::vector<TokenRangeRetentionConfig> token_range_retention_configs;
    uint32_t decode_retention_priority;
    std::optional<uint64_t> decode_duration_ms;
};

inline void to_json(json& j, const KvCacheRetentionConfig& k)
{
    j = json{{"token_range_retention_configs", k.token_range_retention_configs},
             {"decode_retention_priority", k.decode_retention_priority}};
    if (k.decode_duration_ms) j["decode_duration_ms"] = k.decode_duration_ms.value();
}

inline void from_json(const json& j, KvCacheRetentionConfig& k)
{
    j.at("token_range_retention_configs").get_to(k.token_range_retention_configs);
    j.at("decode_retention_priority").get_to(k.decode_retention_priority);
    if (j.contains("decode_duration_ms")) k.decode_duration_ms = j.at("decode_duration_ms").get<uint64_t>();
}

// Request Struct
struct Request {
    std::vector<int32_t> input_token_ids;
    uint32_t max_tokens;
    bool streaming;
    std::optional<SamplingConfig> sampling_config;
    std::optional<OutputConfig> output_config;
    std::optional<uint32_t> end_id;
    // std::optional<uint32_t> pad_id;
    // std::vector<uint32_t> position_ids;
    // std::vector<uint32_t> bad_words;
    // std::vector<uint32_t> stop_words;
    // std::vector<uint8_t> embedding_bias; // bytes
    // // TODO: Add ExternalDraftTokensConfig external_draft_tokens_config;
    // // TODO: Add PromptTuningConfig prompt_tuning_config;
    // // TODO: Add LoraConfig lora_config;
    // // TODO: Add LookaheadDecodingConfig lookahead_config;
    // KvCacheRetentionConfig kv_cache_retention_config;
    // std::string logits_post_processor_name;
    // std::vector<uint32_t> encoder_input_token_ids;
    // std::optional<uint64_t> client_id;
    // bool return_all_generated_tokens;
    // float priority;
    // uint32_t request_type;
    // // TODO: Add ContextPhaseParams context_phase_params;
    // std::vector<uint8_t> encoder_input_features; // bytes
    // std::optional<uint32_t> encoder_output_length;
    // std::vector<uint8_t> cross_attention_mask; // bytes
    // uint32_t num_return_sequences;
    // // TODO: Add EagleConfig eagle_config;
    // std::vector<uint8_t> skip_cross_attn_blocks; // bytes
};

// Custom to_json and from_json functions for Request
inline void to_json(json& j, const Request& r)
{
    j = json{
        {"input_token_ids", r.input_token_ids},
        {"max_tokens", r.max_tokens},
        {"streaming", r.streaming},
        // {"sampling_config", r.sampling_config},
        // {"output_config", r.output_config},
        // {"position_ids", r.position_ids},
        // {"bad_words", r.bad_words},
        // {"stop_words", r.stop_words},
        // {"kv_cache_retention_config", r.kv_cache_retention_config},
        // {"logits_post_processor_name", r.logits_post_processor_name},
        // {"encoder_input_token_ids", r.encoder_input_token_ids},
        // {"return_all_generated_tokens", r.return_all_generated_tokens},
        // {"priority", r.priority},
        // {"request_type", r.request_type},
        // {"num_return_sequences", r.num_return_sequences}
    };

    if (r.sampling_config) j["sampling_config"] = r.sampling_config.value();
    if (r.output_config) j["output_config"] = r.output_config.value();
    if (r.end_id) j["end_id"] = r.end_id.value();
    // if (r.pad_id)
    //     j["pad_id"] = r.pad_id.value();
    // if (!r.embedding_bias.empty())
    //     j["embedding_bias"] = r.embedding_bias;
    // if (r.client_id)
    //     j["client_id"] = r.client_id.value();
    // if (!r.encoder_input_features.empty())
    //     j["encoder_input_features"] = r.encoder_input_features;
    // if (r.encoder_output_length)
    //     j["encoder_output_length"] = r.encoder_output_length.value();
    // if (!r.cross_attention_mask.empty())
    //     j["cross_attention_mask"] = r.cross_attention_mask;
    // if (!r.skip_cross_attn_blocks.empty())
    //     j["skip_cross_attn_blocks"] = r.skip_cross_attn_blocks;
}

inline void from_json(const json& j, Request& r)
{
    j.at("input_token_ids").get_to(r.input_token_ids);
    j.at("max_tokens").get_to(r.max_tokens);
    j.at("streaming").get_to(r.streaming);

    if (j.contains("sampling_config")) r.sampling_config = j.at("sampling_config").get<SamplingConfig>();
    if (j.contains("output_config")) r.output_config = j.at("output_config").get<OutputConfig>();

    // j.at("sampling_config").get_to(r.sampling_config);
    // j.at("output_config").get_to(r.output_config);
    // j.at("position_ids").get_to(r.position_ids);
    // j.at("bad_words").get_to(r.bad_words);
    // j.at("stop_words").get_to(r.stop_words);
    // j.at("kv_cache_retention_config").get_to(r.kv_cache_retention_config);
    // j.at("logits_post_processor_name").get_to(r.logits_post_processor_name);
    // j.at("encoder_input_token_ids").get_to(r.encoder_input_token_ids);
    // j.at("return_all_generated_tokens").get_to(r.return_all_generated_tokens);
    // j.at("priority").get_to(r.priority);
    // j.at("request_type").get_to(r.request_type);
    // j.at("num_return_sequences").get_to(r.num_return_sequences);

    if (j.contains("end_id")) r.end_id = j.at("end_id").get<uint32_t>();
    // if (j.contains("pad_id"))
    //     r.pad_id = j.at("pad_id").get<uint32_t>();
    // if (j.contains("embedding_bias"))
    //     r.embedding_bias = j.at("embedding_bias").get<std::vector<uint8_t>>();
    // if (j.contains("client_id"))
    //     r.client_id = j.at("client_id").get<uint64_t>();
    // if (j.contains("encoder_input_features"))
    //     r.encoder_input_features = j.at("encoder_input_features").get<std::vector<uint8_t>>();
    // if (j.contains("encoder_output_length"))
    //     r.encoder_output_length = j.at("encoder_output_length").get<uint32_t>();
    // if (j.contains("cross_attention_mask"))
    //     r.cross_attention_mask = j.at("cross_attention_mask").get<std::vector<uint8_t>>();
    // if (j.contains("skip_cross_attn_blocks"))
    //     r.skip_cross_attn_blocks = j.at("skip_cross_attn_blocks").get<std::vector<uint8_t>>();
}

tensorrt_llm::executor::Request deserialize_request(const std::string& request_proto)
{
    spdlog::trace("Deserializing request json: {}", request_proto);
    auto j = json::parse(request_proto);
    auto req_in = j.get<Request>();

    spdlog::trace(
        "constructing request with {} input tokens; max tokens: {}", req_in.input_token_ids.size(), req_in.max_tokens);

    tensorrt_llm::executor::Request request(std::move(req_in.input_token_ids), req_in.max_tokens, true);

    if (req_in.sampling_config) {
        spdlog::trace("Setting sampling_config");
        request.setSamplingConfig(req_in.sampling_config->to_executor_config());
    }

    if (req_in.end_id) {
        spdlog::trace("Setting end_id: {}", req_in.end_id.value());
        request.setEndId(req_in.end_id.value());
    }

    return request;
}

}  // namespace nvidia::nvllm::trt
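
The optional-field pattern above (emit a key only when the std::optional is set, read it back only when the key is present) is what keeps absent sampling parameters from overriding the engine's defaults. A minimal, self-contained sketch of the same round-trip, using only nlohmann::json and a hypothetical two-field "MiniConfig", might look like this:

// Sketch only: illustrates the optional-field JSON round-trip used by SamplingConfig above.
// "MiniConfig" and its fields are hypothetical; only nlohmann::json is assumed.
#include <nlohmann/json.hpp>
#include <cstdint>
#include <iostream>
#include <optional>

using json = nlohmann::json;

struct MiniConfig {
    uint32_t beam_width = 1;           // always serialized
    std::optional<float> temperature;  // serialized only when set
};

void to_json(json& j, const MiniConfig& c)
{
    j = json{{"beam_width", c.beam_width}};
    if (c.temperature) j["temperature"] = c.temperature.value();
}

void from_json(const json& j, MiniConfig& c)
{
    j.at("beam_width").get_to(c.beam_width);
    if (j.contains("temperature")) c.temperature = j.at("temperature").get<float>();
}

int main()
{
    MiniConfig a{4, 0.7f};
    std::string wire = json(a).dump();  // {"beam_width":4,"temperature":0.7}
    MiniConfig b = json::parse(wire).get<MiniConfig>();
    std::cout << wire << " -> beam_width=" << b.beam_width << "\n";

    // A payload without "temperature" leaves the optional empty rather than defaulting it.
    MiniConfig c = json::parse(R"({"beam_width":2})").get<MiniConfig>();
    std::cout << "temperature set: " << c.temperature.has_value() << "\n";
    return 0;
}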

lib/bindings/cpp/nvllm-trt/src/engine_trt/request.hpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace nvidia::nvllm::trt {

tensorrt_llm::executor::Request deserialize_request(const std::string& request);

}  // namespace nvidia::nvllm::trt

lib/bindings/cpp/nvllm-trt/src/engine_trt/response.cpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/response.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;

namespace ex = tensorrt_llm::executor;

namespace nvidia::nvllm::trt {

// Forward declarations
struct Response;
struct Output;

enum FinishReasonEnum {
    FINISH_REASON_NOT_DONE = 0,
    FINISH_REASON_EOS = 1,
    FINISH_REASON_STOP = 2,
    FINISH_REASON_LENGTH = 3,
};

// Output Struct
struct Output {
    bool is_final;
    std::vector<int32_t> token_ids;
    std::optional<float> cum_log_prob;
    std::optional<std::vector<float>> log_probs;
    std::optional<FinishReasonEnum> finish_reason;
};

// Custom to_json function
void to_json(json& j, const Output& o)
{
    j = json{{"is_final", o.is_final}, {"token_ids", o.token_ids}};
    if (o.cum_log_prob) {
        j["cum_log_prob"] = *o.cum_log_prob;
    }
    if (o.log_probs) {
        j["log_probs"] = *o.log_probs;
    }
    if (o.finish_reason) {
        j["finish_reason"] = static_cast<int>(*o.finish_reason);
    }
}

void from_json(const json& j, Output& o)
{
    j.at("is_final").get_to(o.is_final);
    j.at("token_ids").get_to(o.token_ids);

    if (j.contains("cum_log_prob") && !j["cum_log_prob"].is_null()) {
        o.cum_log_prob = j["cum_log_prob"].get<float>();
    } else {
        o.cum_log_prob = std::nullopt;
    }

    if (j.contains("log_probs") && !j["log_probs"].is_null()) {
        o.log_probs = j["log_probs"].get<std::vector<float>>();
    } else {
        o.log_probs = std::nullopt;
    }

    if (j.contains("finish_reason") && !j["finish_reason"].is_null()) {
        o.finish_reason = static_cast<FinishReasonEnum>(j["finish_reason"].get<int>());
    } else {
        o.finish_reason = std::nullopt;
    }
}

// Response Struct
struct Response {
    uint64_t request_id;
    std::optional<uint64_t> client_id;  // Optional client ID.
    std::optional<std::string> error_msg;
    std::optional<Output> output;
};

inline void to_json(json& j, const Response& p)
{
    j = json{{"request_id", p.request_id}};
    if (p.client_id) j["client_id"] = p.client_id.value();
    if (p.error_msg) j["error_msg"] = p.error_msg.value();
    if (p.output) j["output"] = p.output.value();
}

inline void from_json(const json& j, Response& p)
{
    j.at("request_id").get_to(p.request_id);
    if (j.contains("client_id")) p.client_id = j.at("client_id").get<uint64_t>();
    if (j.contains("error_msg")) p.error_msg = j.at("error_msg").get<std::string>();
    if (j.contains("output")) p.output = j.at("output").get<Output>();
}

// Responses Struct
struct Responses {
    std::vector<Response> responses;
    bool shutdown;
};

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Responses, responses, shutdown)

Response convert(ex::Response&& response)
{
    auto request_id = response.getRequestId();
    auto client_id = response.getClientId();

    if (response.hasError()) {
        auto error_msg = response.getErrorMsg();
        return Response{request_id, client_id, {error_msg}, std::nullopt};
    }

    auto e_output = response.getResult();
    auto is_final = e_output.isFinal;

    assert(e_output.outputTokenIds.size() == 1);
    auto token_ids = std::move(e_output.outputTokenIds[0]);

    auto output = Output{is_final, std::move(token_ids), std::nullopt, std::nullopt, std::nullopt};

    if (e_output.cumLogProbs.has_value()) {
        assert(e_output.cumLogProbs.value().size() == 1);
        output.cum_log_prob = {e_output.cumLogProbs.value()[0]};
    }

    if (e_output.logProbs.has_value()) {
        assert(e_output.logProbs.value().size() == 1);
        output.log_probs = {std::move(e_output.logProbs.value()[0])};
    }

    if (e_output.finishReasons.size() > 0) {
        assert(e_output.finishReasons.size() == 1);
        auto finish_reason = static_cast<FinishReasonEnum>(e_output.finishReasons[0]);
        if (finish_reason != FinishReasonEnum::FINISH_REASON_NOT_DONE) {
            output.finish_reason = {finish_reason};
        }
    }

    return Response{request_id, client_id, std::nullopt, {output}};
}

std::string serialize_responses(std::deque<ex::Response> responses, bool shutdown)
{
    auto object = Responses{};
    object.shutdown = shutdown;

    while (!responses.empty()) {
        auto response = std::move(responses.front());
        responses.pop_front();
        auto r = convert(std::move(response));
        assert(r.output.has_value() || r.error_msg.has_value());
        object.responses.emplace_back(std::move(r));
    }

    return json(object).dump();
}

}  // namespace nvidia::nvllm::trt
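
Based solely on the structs and the NLOHMANN macro above, the string returned by serialize_responses() is a JSON object wrapping an array of per-request outputs. The following sketch hand-builds one such payload with nlohmann::json so the wire shape is easy to see; all values are illustrative, not taken from a real run.

// Sketch only: reproduces the shape emitted by serialize_responses() above with made-up values.
#include <nlohmann/json.hpp>
#include <iostream>

int main()
{
    using json = nlohmann::json;

    json output;
    output["is_final"] = true;
    output["token_ids"] = {1542, 318, 257};
    output["cum_log_prob"] = -4.21;
    output["finish_reason"] = 1;  // 1 == FINISH_REASON_EOS in the enum above

    json response;
    response["request_id"] = 7;
    response["client_id"] = 3;
    response["output"] = output;

    json payload;
    payload["shutdown"] = false;
    payload["responses"] = json::array({response});

    std::cout << payload.dump(2) << std::endl;
    return 0;
}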

lib/bindings/cpp/nvllm-trt/src/engine_trt/response.hpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace nvidia::nvllm::trt {

std::string serialize_responses(std::deque<tensorrt_llm::executor::Response> responses, bool shutdown);

}  // namespace nvidia::nvllm::trt

lib/bindings/cpp/nvllm-trt/src/engine_trt/stats.cpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/stats.hpp"
#include <nlohmann/json.hpp>
#include <deque>
using json = nlohmann::json;

namespace nvidia::nvllm::trt {

std::string serialize_iter_stats(std::deque<::tensorrt_llm::executor::IterationStats> stats)
{
    json json_stats = json::array();
    for (const auto& stat : stats) {
        if (stat.kvCacheStats.has_value()) {
            json entry;
            entry["iter"] = stat.iter;
            entry["kv_active_blocks"] = stat.kvCacheStats->usedNumBlocks;
            entry["kv_total_blocks"] = stat.kvCacheStats->maxNumBlocks;
            entry["request_active_slots"] = stat.numActiveRequests;
            entry["request_total_slots"] = stat.maxNumActiveRequests;
            entry["request_new_active_slots"] = stat.numNewActiveRequests;
            json_stats.push_back(entry);
        } else {
            json entry;
            entry["iter"] = stat.iter;
            entry["request_active_slots"] = stat.numActiveRequests;
            entry["request_total_slots"] = stat.maxNumActiveRequests;
            entry["request_new_active_slots"] = stat.numNewActiveRequests;
            json_stats.push_back(entry);
        }
    }
    return json_stats.dump();
}

}  // namespace nvidia::nvllm::trt
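
For reference, the Rust side consumed this as a flat JSON array of per-iteration objects. A single illustrative element (all numbers made up) with KV-cache stats present can be reproduced like so:

// Sketch only: one made-up element in the shape produced by serialize_iter_stats() above.
#include <nlohmann/json.hpp>
#include <iostream>

int main()
{
    nlohmann::json entry;
    entry["iter"] = 1024;
    entry["kv_active_blocks"] = 311;
    entry["kv_total_blocks"] = 4096;
    entry["request_active_slots"] = 5;
    entry["request_total_slots"] = 64;
    entry["request_new_active_slots"] = 1;

    std::cout << nlohmann::json::array({entry}).dump(2) << std::endl;
    return 0;
}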

lib/bindings/cpp/nvllm-trt/src/engine_trt/stats.hpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace nvidia::nvllm::trt {

std::string serialize_iter_stats(std::deque<tensorrt_llm::executor::IterationStats> stats);

}  // namespace nvidia::nvllm::trt

lib/bindings/cpp/nvllm-trt/src/nvllm_trt.cpp deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nvidia/nvllm/nvllm_trt.h"
#include "api/engine.hpp"
#include <cstring>
extern "C" {

// int trtllm_mpi_session_set_communicator(void* world_comm_ptr)
// {
//     return nvidia::nvllm::trt::MpiSession::set_communicator(world_comm_ptr);
// }

nvllm_trt_engine_t nvllm_trt_engine_create(const char* config_proto)
{
    // based on the type of engine, we might choose to create a different concrete engine object
    try {
        return reinterpret_cast<nvllm_trt_engine_t>(
            new nvidia::nvllm::trt::StreamingEngine(std::string(config_proto)));
    } catch (const std::exception& e) {
        printf("Caught exception when initializing tensorrt_llm: %s\n", e.what());
        return nullptr;
    }
}

nvllm_trt_engine_t nvllm_trt_engine_unsafe_create_from_executor(void* engine)
{
    try {
        return reinterpret_cast<nvllm_trt_engine_t>(new nvidia::nvllm::trt::StreamingEngine(engine));
    } catch (const std::exception& e) {
        printf("Caught exception when initializing from raw pointer: %s\n", e.what());
        return nullptr;
    }
}

request_id_t nvllm_trt_engine_enqueue_request(nvllm_trt_engine_t engine, client_id_t client_id, const char* req_proto)
{
    // Call the enqueue_request method on the C++ class
    try {
        return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->enqueue_request(
            client_id, std::string(req_proto));
    } catch (...) {
        return 0;
    }
}

char* nvllm_trt_engine_await_responses(nvllm_trt_engine_t engine)
{
    auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_responses();
    char* c_responses = strdup(responses.c_str());  // Allocate memory and copy the string
    return c_responses;  // Return the C string (remember to free this in the calling code)
}

char* nvllm_trt_engine_await_kv_events(nvllm_trt_engine_t engine)
{
    auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_kv_events();
    if (!responses) {
        return nullptr;
    }
    char* c_responses = strdup(responses->c_str());  // Allocate memory and copy the string
    return c_responses;  // Return the C string (remember to free this in the calling code)
}

// Get basic iteration stats
char* nvllm_trt_engine_await_iter_stats(nvllm_trt_engine_t engine)
{
    auto responses = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->await_iter_stats();
    if (!responses) {
        return nullptr;
    }
    char* c_responses = strdup(responses->c_str());
    return c_responses;
}

void nvllm_trt_engine_free_responses(char* responses)
{
    free(responses);
}

void nvllm_trt_engine_cancel_request(nvllm_trt_engine_t engine, uint64_t request_id)
{
    reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->cancel_request(request_id);
}

void nvllm_trt_engine_shutdown(nvllm_trt_engine_t engine)
{
    reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->shutdown();
}

int nvllm_trt_engine_destroy(nvllm_trt_engine_t engine)
{
    auto* trtllm_engine = reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine);
    delete trtllm_engine;
    return NVLLM_TRT_ENGINE_SUCCESS;
}

int nvllm_trt_engine_is_ready(nvllm_trt_engine_t engine)
{
    return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->is_ready();
}

int nvllm_trt_engine_has_completed(nvllm_trt_engine_t engine)
{
    return reinterpret_cast<nvidia::nvllm::trt::StreamingEngine*>(engine)->has_completed();
}

// int trtllm_version_major()
// {
//     return TRTLLM_VERSION_MAJOR;
// }

// int trtllm_version_minor()
// {
//     return TRTLLM_VERSION_MINOR;
// }

// int trtllm_version_patch()
// {
//     return TRTLLM_VERSION_PATCH;
// }
}
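
The C ABI above is what the Rust wrapper (cpp.rs, later in this diff) drove via bindgen. A rough sketch of the same call sequence from C++ follows; it assumes the removed nvllm_trt.h header and its library are installed, and the config/request JSON values are illustrative only (their fields come from the deleted config.rs and request.cpp).

// Sketch only: intended create -> enqueue -> await -> free -> shutdown -> destroy sequence.
// Assumes the (now removed) nvidia/nvllm/nvllm_trt.h header and library are available.
#include "nvidia/nvllm/nvllm_trt.h"
#include <cstdio>

int main()
{
    // Illustrative config; field names mirror ExecutorConfig in the deleted config.rs.
    const char* config = R"({"model_path":"/models/my-engine","log_level":"error"})";
    nvllm_trt_engine_t engine = nvllm_trt_engine_create(config);
    if (engine == nullptr) {
        return 1;
    }

    // Illustrative request; field names mirror the Request struct in the deleted request.cpp.
    const char* request = R"({"input_token_ids":[1,2,3],"max_tokens":16,"streaming":true})";
    request_id_t id = nvllm_trt_engine_enqueue_request(engine, /*client_id=*/0, request);
    std::printf("enqueued request %llu\n", (unsigned long long) id);

    // Blocks until a batch of responses is ready; caller must free the returned string.
    char* responses = nvllm_trt_engine_await_responses(engine);
    std::printf("%s\n", responses);
    nvllm_trt_engine_free_responses(responses);

    nvllm_trt_engine_shutdown(engine);
    nvllm_trt_engine_destroy(engine);
    return 0;
}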

lib/engines/trtllm/Cargo.toml deleted (file mode 100644 → 0)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "dynamo-engine-trtllm"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true

[dependencies]
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }

anyhow = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }

async-openai = "0.27.2"
serde_repr = "0.1"

[build-dependencies]
bindgen = "0.70"
cmake = "0.1"

lib/engines/trtllm/build.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
fn main() {
    extern crate bindgen;
    use cmake::Config;
    use std::env;
    use std::path::PathBuf;

    let installed_headers = "/usr/local/include/nvidia/nvllm/nvllm_trt.h";
    let local_headers = "../../bindings/cpp/nvllm-trt/include/nvidia/nvllm/nvllm_trt.h";
    let headers_path;

    if PathBuf::from(installed_headers).exists() {
        headers_path = installed_headers;
        println!("cargo:warning=nvllm found. Building with installed version...");
        println!("cargo:rustc-link-search=native=/usr/local/lib");
        println!("cargo:rustc-link-search=native=/opt/tensorrt_llm/lib");
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm_nvrtc_wrapper");
        println!("cargo:rustc-link-lib=dylib=nvinfer_plugin_tensorrt_llm");
        println!("cargo:rustc-link-lib=dylib=decoder_attention");
        println!("cargo:rerun-if-changed=/usr/local/lib");
    } else if PathBuf::from(local_headers).exists() {
        headers_path = local_headers;
        println!("cargo:warning=nvllm not found. Building stub version...");
        let dst = Config::new("../../bindings/cpp/nvllm-trt")
            .define("USE_STUBS", "ON")
            .no_build_target(true)
            .build();
        println!("cargo:warning=building stubs in {}", dst.display());
        let dst = dst.canonicalize().unwrap();
        println!("cargo:rustc-link-search=native={}/build", dst.display());
        println!("cargo:rustc-link-lib=dylib=nvllm_trt");
        println!("cargo:rustc-link-lib=dylib=tensorrt_llm");
        println!("cargo:rerun-if-changed=../bindings/cpp/nvllm-trt");
    } else {
        panic!("nvllm_trt.h not found");
    }

    // generate bindings for the trtllm c api
    let bindings = bindgen::Builder::default()
        .header(headers_path)
        .generate()
        .expect("Unable to generate bindings");

    // Write the bindings to a file
    let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
    bindings
        .write_to_file(out_path.join("bindings.rs"))
        .expect("Could not write bindings!");

    // // Build protobuf
    // tonic_build::configure()
    //     .build_server(false)
    //     .compile_protos(&["../../proto/trtllm.proto"], &["../../proto"])
    //     .expect("Failed to compile protos");
}

lib/engines/trtllm/src/executor.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod cpp;
mod engine;
mod processors;

// pub mod protos {
//     include!(concat!(env!("OUT_DIR"), "/nvidia.nvllm.trt.proto.rs"));
// }

pub mod protocols;

pub mod config;

use anyhow::Result;
use std::{
    collections::HashMap,
    ffi::CString,
    sync::{atomic::AtomicU64, Arc, Mutex, OnceLock, Weak},
};
use tokio::sync::mpsc;

use processors::{
    IterationProcessor, IterationStatsSubscriptionChannel, KvEventProcessor, KvEventSubscriptionChannel,
    ProcessorState, ResponseProcessor,
};

pub struct Executor {
    executor: Arc<cpp::Executor>,
    next_id: AtomicU64,
    response_queues: ResponseQueues,
    response_processor: OnceLock<ResponseProcessor>,
    kv_event_processor: OnceLock<KvEventProcessor>,
    iteration_processor: OnceLock<IterationProcessor>,
}

type ResponseQueues = Arc<Mutex<HashMap<u64, mpsc::Sender<Result<protocols::Output>>>>>;

impl Executor {
    pub fn from_model_path<P: ToString>(model_path: P) -> Result<Self> {
        let config = config::ExecutorConfig::new(model_path.to_string());
        Self::new(config)
    }

    pub fn new(config: config::ExecutorConfig) -> Result<Self> {
        Ok(Self {
            executor: Arc::new(cpp::Executor::new(config)?),
            next_id: AtomicU64::new(0),
            response_queues: Arc::new(Mutex::new(HashMap::new())),
            response_processor: OnceLock::new(),
            kv_event_processor: OnceLock::new(),
            iteration_processor: OnceLock::new(),
        })
    }

    pub fn has_started(&self) -> bool {
        self.executor.has_started()
    }

    pub fn has_completed(&self) -> bool {
        self.executor.has_completed()
    }

    pub fn enqueue_request(&self, request: protocols::Request) -> Result<ExecutionContext> {
        let client_id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let (tx, rx) = mpsc::channel(128);

        self.response_queues
            .lock()
            .expect("response_queues lock poisoned")
            .insert(client_id, tx);

        let json = serde_json::to_string(&request)?;
        let str = CString::new(json)?;

        let request_id = self
            .executor
            .enqueue_request(client_id, str)
            .inspect_err(|_| {
                self.response_queues
                    .lock()
                    .expect("response_queues lock poisoned")
                    .remove(&client_id);
            })?;

        println!("request_id: {}", request_id);

        Ok(ExecutionContext {
            request_id,
            response_rx: Some(rx),
            executor: Arc::downgrade(&self.executor),
        })
    }

    pub fn cancel_request(&self, client_id: u64) {
        self.executor.cancel_request(client_id)
    }

    /// Start a background task to process responses from the TensorRT LLM AsyncEngine
    pub fn start_response_processor(&self) {
        self.response_processor.get_or_init(|| {
            ResponseProcessor::new(self.create_processor(), self.response_queues.clone())
        });
    }

    /// Starts a background task to process kv events
    /// TODO - check the TensorRT LLM config and only start this if the server is configured to send kv events
    pub fn start_kv_event_processor(&self) {
        self.kv_event_processor
            .get_or_init(|| KvEventProcessor::new(self.create_processor()));
    }

    /// Starts a background task to process forward pass / iteration statistics
    pub fn start_iteration_metrics_processor(&self) {
        self.iteration_processor
            .get_or_init(|| IterationProcessor::new(self.create_processor()));
    }

    /// Subscribes to the KV Events broadcast channel
    pub fn subscribe_to_kv_events(&self) -> Result<KvEventSubscriptionChannel> {
        self.kv_event_processor
            .get_or_init(|| KvEventProcessor::new(self.create_processor()))
            .subscribe()
            .ok_or(anyhow::anyhow!("Failed to subscribe to KV events"))
    }

    pub fn subscribe_to_iteration_stats(&self) -> Result<IterationStatsSubscriptionChannel> {
        self.iteration_processor
            .get_or_init(|| IterationProcessor::new(self.create_processor()))
            .subscribe()
            .ok_or(anyhow::anyhow!("Failed to subscribe to iteration stats"))
    }

    /// Issues a shutdown request to the TensorRT LLM AsyncEngine
    /// This is a blocking call. After the async engine has shutdown each background processor/thread/task
    /// will be joined and the resources will be released.
    pub fn shutdown(&mut self) {
        self.executor.shutdown();
        self.response_processor.take().map(|p| p.join());
        self.kv_event_processor.take().map(|p| p.join());
        self.iteration_processor.take().map(|p| p.join());
    }

    /// Constructs a new ProcessorState instance which packages up any bits from the Executor for the processor task
    fn create_processor(&self) -> ProcessorState {
        ProcessorState::new(self.executor.clone())
    }
}

impl Drop for Executor {
    fn drop(&mut self) {
        self.shutdown();
    }
}

pub struct ExecutionContext {
    /// Internal TensorRT LLM request_id; used to cancel the request
    /// This value is present in the response but because we do not know it beforehand, it is only used for cancellation
    request_id: u64,

    /// Hold a weak pointer to the executor for cancellation
    executor: Weak<cpp::Executor>,

    /// Response stream associated with this request
    response_rx: Option<mpsc::Receiver<Result<protocols::Output>>>,
}

impl ExecutionContext {
    pub fn cancel(&self) {
        if let Some(executor) = self.executor.upgrade() {
            executor.cancel_request(self.request_id);
        }
    }

    pub fn take_response_rx(&mut self) -> Option<mpsc::Receiver<Result<protocols::Output>>> {
        self.response_rx.take()
    }
}

lib/engines/trtllm/src/executor/config.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, Default, Builder)]
pub struct ExecutorConfig {
    model_path: String,

    #[builder(default = "LogLevel::Error")]
    log_level: LogLevel,

    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    enable_chunked_context: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    normalize_log_probs: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    iter_stats_max_iterations: Option<u32>,

    /// The number of processes for tensor parallelism. Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    tensor_parallel_size: Option<u32>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum LogLevel {
    #[default]
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}

impl From<&str> for LogLevel {
    fn from(value: &str) -> Self {
        match value.to_lowercase().as_str() {
            "error" => LogLevel::Error,
            "warn" => LogLevel::Warn,
            "info" => LogLevel::Info,
            "debug" => LogLevel::Debug,
            "trace" => LogLevel::Trace,
            _ => LogLevel::default(), // Default to Error if no match
        }
    }
}

impl ExecutorConfig {
    pub fn builder() -> ExecutorConfigBuilder {
        ExecutorConfigBuilder::default()
    }

    pub fn new(model_path: String) -> Self {
        Self {
            model_path,
            log_level: LogLevel::Error,
            enable_chunked_context: None,
            normalize_log_probs: None,
            iter_stats_max_iterations: None,
            tensor_parallel_size: None,
        }
    }
}

lib/engines/trtllm/src/executor/cpp.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Context, Error, Result};
use bindings::nvllm_trt_engine_destroy;
use std::ffi::CStr;
use std::ffi::CString;
use std::ptr::NonNull;

use super::protocols;
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, KvCacheEvents};

mod bindings {
    #![allow(warnings, missing_docs)]
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}

use bindings::{
    nvllm_trt_engine, nvllm_trt_engine_await_iter_stats, nvllm_trt_engine_await_kv_events,
    nvllm_trt_engine_await_responses, nvllm_trt_engine_cancel_request, nvllm_trt_engine_create,
    nvllm_trt_engine_enqueue_request, nvllm_trt_engine_free_responses, nvllm_trt_engine_has_completed,
    nvllm_trt_engine_is_ready, nvllm_trt_engine_shutdown,
};

use super::config;

#[derive(Debug, Clone)]
pub struct Executor {
    engine: NonNull<nvllm_trt_engine>,
}

// nvllm_trt_engine is thread safe
// rust does not know that it is thread safe, so we have to tell it
unsafe impl Send for Executor {}
unsafe impl Sync for Executor {}

// The following implementation of the thread-safe engine provides the convenience methods used to call
// the C/C++ TensorRT API from Rust.
impl Executor {
    /// Creates a new instance of the TensorRT LLM engine and takes ownership of the pointer to
    /// the C/C++ TensorRT LLM engine object.
    ///
    /// Executor implements the Drop trait, so this object is an RAII object and will
    /// free the C/C++ TensorRT LLM engine object when it goes out of scope.
    pub fn new(config: config::ExecutorConfig) -> Result<Self> {
        let json = serde_json::to_string(&config)?;
        let c_config = CString::new(json)?;
        let engine = unsafe { nvllm_trt_engine_create(c_config.as_ptr()) };
        let engine = NonNull::new(engine)
            .ok_or_else(|| Error::msg("Failed to create nvllm_trt_engine".to_string()))?;
        Ok(Self { engine })
    }

    /// Checks if the engine has started asking for new work
    pub fn has_started(&self) -> bool {
        let result = unsafe { nvllm_trt_engine_is_ready(self.engine.as_ptr()) };
        if result != 0 {
            return true;
        }
        false
    }

    /// Checks if the engine has completed all work and shutdown
    pub fn has_completed(&self) -> bool {
        let result = unsafe { nvllm_trt_engine_has_completed(self.engine.as_ptr()) };
        if result != 0 {
            return true;
        }
        false
    }

    /// Enqueues a request to the engine
    /// The request is sent to the engine as a json encoded string; however, we reserve the right to change
    /// the encoding in the future.
    pub fn enqueue_request(&self, client_id: u64, request: CString) -> Result<u64> {
        tracing::trace!("enqueuing request to trtllm engine");
        let id = unsafe {
            nvllm_trt_engine_enqueue_request(self.engine.as_ptr(), client_id, request.as_ptr())
        };
        if id == 0 {
            return Err(Error::msg("Failed to enqueue request".to_string()));
        }
        Ok(id)
    }

    /// Block on [`nvllm_trt_engine_await_responses`] until a set of responses is received
    /// If the server shutdown, the list of Responses will be empty
    pub fn await_responses(&self) -> Result<protocols::Responses> {
        let responses;
        unsafe {
            let ptr = nvllm_trt_engine_await_responses(self.engine.as_ptr());
            let c_str = CStr::from_ptr(ptr);
            let bytes = c_str.to_bytes();
            responses = serde_json::from_slice(bytes).context("Failed to parse responses")?;
            nvllm_trt_engine_free_responses(ptr);
        }
        Ok(responses)
    }

    pub fn await_kv_events(&self) -> Result<KvCacheEvents> {
        let events: KvCacheEvents;
        unsafe {
            let ptr = nvllm_trt_engine_await_kv_events(self.engine.as_ptr());
            if ptr.is_null() {
                return Err(Error::msg(
                    "No KvEvents will be emitted for this model".to_string(),
                ));
            }
            let c_str = CStr::from_ptr(ptr);
            let bytes = c_str.to_bytes();
            events = serde_json::from_slice(bytes)
                .context(format!("Failed to parse kv cache events: {:?}", c_str))?;
            nvllm_trt_engine_free_responses(ptr);
        }
        Ok(events)
    }

    #[allow(dead_code)]
    pub fn await_iter_stats(&self) -> Result<protocols::stats::IterStats> {
        let stats: Vec<ForwardPassMetrics>;
        unsafe {
            let ptr = nvllm_trt_engine_await_iter_stats(self.engine.as_ptr());
            if ptr.is_null() {
                return Err(Error::msg(
                    "No iter stats will be emitted for this model".to_string(),
                ));
            }
            let c_str = CStr::from_ptr(ptr);
            let bytes = c_str.to_bytes();
            stats = serde_json::from_slice(bytes)
                .context(format!("Failed to parse iter stats: {:?}", c_str))?;
            nvllm_trt_engine_free_responses(ptr);
        }
        let stats = protocols::stats::IterStats { stats };
        Ok(stats)
    }

    /// Cancels a request by its request_id
    pub fn cancel_request(&self, request_id: u64) {
        unsafe { nvllm_trt_engine_cancel_request(self.engine.as_ptr(), request_id) };
    }

    /// Shuts down the engine
    pub fn shutdown(&self) {
        unsafe { nvllm_trt_engine_shutdown(self.engine.as_ptr()) };
    }
}

impl Drop for Executor {
    fn drop(&mut self) {
        unsafe {
            nvllm_trt_engine_shutdown(self.engine.as_ptr());
            nvllm_trt_engine_destroy(self.engine.as_ptr());
        }
    }
}

lib/engines/trtllm/src/executor/engine.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::{Error, Result};
use async_trait::async_trait;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::stream;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

use dynamo_llm::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};

use super::Executor;

struct State {
    request_id: String,
    cancel_token: CancellationToken,
    response_rx: mpsc::Receiver<Result<super::protocols::Output>>,
    _link_to_cancel_task: tokio::sync::oneshot::Receiver<()>,

    // set to true if we send what we expect to be a final message
    // if the engine's response stream is closed before we send a final message, we can
    // detect that condition and report an unknown error engine stream termination event
    sentinel: bool,
}

// impl Drop for State {
//     fn drop(&mut self) {
//         tracing::trace!(request_id = self.stream.id(), "dropping state");
//     }
// }

#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error> for Executor {
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        // unpack the request and context
        let (request, context) = request.into_parts();

        // grab the core context
        let context = context.context();
        let context_cloned = context.clone();

        // create a cancellation token and request id
        let cancel_token = CancellationToken::new();
        let request_id = context.id().to_string();

        let mut engine_context = self.enqueue_request(request.into())?;

        let (mut tx, rx) = tokio::sync::oneshot::channel::<()>();

        let state = State {
            request_id,
            cancel_token: cancel_token.clone(),
            _link_to_cancel_task: rx,
            response_rx: engine_context
                .take_response_rx()
                .ok_or(Error::msg("no response rx"))?,
            sentinel: false,
        };

        // create a task to monitor the request's cancellation state
        // todo: spawn on low priority async thread pool
        tokio::spawn(async move {
            tokio::select! {
                _ = context.stopped() => {
                    tracing::debug!(request_id = context.id(), "request cancelled");
                    engine_context.cancel();
                    cancel_token.cancel();
                }
                _ = tx.closed() => {
                    tracing::debug!(request_id = context.id(), "response stream closed");
                }
            }
        });

        // create the response stream
        let stream = stream::unfold(state, |mut state| async move {
            if state.sentinel {
                tracing::debug!(request_id = state.request_id, "sentinel set, closing stream");
                return None;
            }

            let output = tokio::select! {
                biased;

                // await a response from the trtllm engine's response processor
                output = state.response_rx.recv() => {
                    output
                }

                // if the stream is stopped, we need to:
                // - cancel the request on the trtllm engine
                // - return an output with a finish reason of cancelled
                // - mark the state as completed by setting the sentinel to true
                _ = state.cancel_token.cancelled() => {
                    tracing::debug!(request_id = state.request_id, "request cancelled");
                    // state.engine.cancel();
                    state.sentinel = true;
                    let output = LLMEngineOutput::cancelled();
                    return Some((Annotated::from_data(output), state))
                }
            };

            match output {
                Some(Ok(output)) => {
                    if output.is_final {
                        tracing::debug!(request_id = state.request_id, "final response");
                        state.sentinel = true;
                    }
                    tracing::trace!(request_id = state.request_id, "issue response");
                    let output = LLMEngineOutput::from(output);
                    Some((Annotated::from_data(output), state))
                }
                Some(Err(err)) => {
                    tracing::debug!(request_id = state.request_id, "request failed: {:?}", err);
                    state.sentinel = true;
                    Some((Annotated::from_error(err.to_string()), state))
                }
                None => {
                    tracing::debug!(request_id = state.request_id, "request completed");
                    if !state.sentinel {
                        tracing::warn!(
                            request_id = state.request_id,
                            "engine stream terminated before final response or error"
                        );
                        state.sentinel = true;
                        Some((
                            Annotated::<LLMEngineOutput>::from_error(
                                "engine stream terminated before final response".to_string(),
                            ),
                            state,
                        ))
                    } else {
                        None
                    }
                }
            }
        });

        Ok(ResponseStream::new(Box::pin(stream), context_cloned))
    }
}

lib/engines/trtllm/src/executor/processors.rs deleted (file mode 100644 → 0)
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{cpp, protocols};
use anyhow::Result;
use std::sync::Arc;

mod iteration;
mod kv;
mod response;

pub use iteration::{IterationProcessor, SubscriptionChannel as IterationStatsSubscriptionChannel};
pub use kv::{KvEventProcessor, KvEventSubscriptionChannel};
pub use response::ResponseProcessor;

#[derive(Debug)]
pub(crate) struct ProcessorState {
    executor: Arc<cpp::Executor>,
}

impl ProcessorState {
    pub fn new(executor: Arc<cpp::Executor>) -> Self {
        Self { executor }
    }
}
lib/engines/trtllm/src/executor/processors/iteration.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Weak,
    },
    thread,
};
use tokio::sync::broadcast;

use super::*;

const CHANNEL_CAPACITY: usize = 256;

type ChannelType = broadcast::Sender<Arc<ForwardPassMetrics>>;
pub type SubscriptionChannel = broadcast::Receiver<Arc<ForwardPassMetrics>>;

pub struct IterationProcessor {
    handle: thread::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
    channel: Weak<ChannelType>,
}

impl IterationProcessor {
    /// Creates a new iteration stats processor
    pub fn new(state: ProcessorState) -> Self {
        // Shutdown token
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();

        // Event channel
        let channel = Arc::new(broadcast::channel(CHANNEL_CAPACITY).0);
        let channel_clone = channel.clone();

        let handle = std::thread::spawn(move || {
            process_events(state, shutdown_clone, channel_clone);
        });

        IterationProcessor {
            handle,
            shutdown,
            channel: Arc::downgrade(&channel),
        }
    }

    /// Subscribes to the iteration stats broadcast channel.
    /// Multiple subscribers can be created to monitor the iteration stats.
    pub fn subscribe(&self) -> Option<SubscriptionChannel> {
        self.channel.upgrade().map(|channel| channel.subscribe())
    }

    /// Joins the thread and waits for it to finish
    pub fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}

fn process_events(state: ProcessorState, shutdown: Arc<AtomicBool>, channel: Arc<ChannelType>) {
    loop {
        // this blocks the thread until the stats are ready or the server is shut down
        let iters = state
            .executor
            .await_iter_stats()
            .expect("Failed to await iteration stats");

        let should_shutdown = shutdown.load(Ordering::Relaxed);

        for iter in iters.stats {
            tracing::debug!("Received iteration stats: {:?}", iter);
            let iter = Arc::new(iter);
            if let Err(e) = channel.send(iter) {
                tracing::debug!("Failed to send message to channel: {:?}", e);
                break;
            }
        }

        if should_shutdown {
            tracing::debug!("Shutting down iteration stats processor");
            break;
        }
    }
}
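The processor above follows a pattern worth calling out: a blocking worker thread holds the only strong reference to a tokio broadcast Sender, the owning struct keeps just a Weak handle so subscribe() degrades gracefully once the worker exits, and an AtomicBool requests shutdown. A hypothetical, self-contained sketch of that pattern (generic u64 payload, no TRT-LLM types) follows:

// Sketch only: the worker publishes counters instead of engine stats.
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc, Weak,
};
use std::thread;
use tokio::sync::broadcast;

struct Processor {
    handle: thread::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
    channel: Weak<broadcast::Sender<u64>>,
}

impl Processor {
    fn new() -> Self {
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();
        // the worker thread owns the only strong Arc to the sender
        let channel = Arc::new(broadcast::channel(16).0);
        let channel_clone = channel.clone();
        let handle = thread::spawn(move || {
            let mut i = 0u64;
            while !shutdown_clone.load(Ordering::Relaxed) {
                // a send error only means there are currently no subscribers
                let _ = channel_clone.send(i);
                i += 1;
                thread::sleep(std::time::Duration::from_millis(10));
            }
        });
        Processor { handle, shutdown, channel: Arc::downgrade(&channel) }
    }

    // returns None once the worker has exited and dropped the sender
    fn subscribe(&self) -> Option<broadcast::Receiver<u64>> {
        self.channel.upgrade().map(|c| c.subscribe())
    }

    fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}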
lib/engines/trtllm/src/executor/processors/kv.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use dynamo_llm::kv_router::protocols::KvCacheEvents;
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Weak,
    },
    thread,
};
use tokio::sync::broadcast;

use super::*;

const KV_EVENT_CHANNEL_CAPACITY: usize = 65536;

type EventChannelType = broadcast::Sender<KvCacheEvents>;
pub type KvEventSubscriptionChannel = broadcast::Receiver<KvCacheEvents>;

pub struct KvEventProcessor {
    handle: thread::JoinHandle<()>,
    shutdown: Arc<AtomicBool>,
    channel: Weak<EventChannelType>,
}

impl KvEventProcessor {
    /// Creates a new KV Event Processor
    pub fn new(state: ProcessorState) -> Self {
        // Shutdown token
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();

        // Event channel
        let channel = Arc::new(broadcast::channel(KV_EVENT_CHANNEL_CAPACITY).0);
        let channel_clone = channel.clone();

        let handle = std::thread::spawn(move || {
            process_events(state, shutdown_clone, channel_clone);
        });

        KvEventProcessor {
            handle,
            shutdown,
            channel: Arc::downgrade(&channel),
        }
    }

    /// Subscribes to the KV Events broadcast channel.
    /// Multiple subscribers can be created to monitor the KV Events.
    pub fn subscribe(&self) -> Option<broadcast::Receiver<KvCacheEvents>> {
        self.channel.upgrade().map(|channel| channel.subscribe())
    }

    /// Joins the thread and waits for it to finish
    pub fn join(self) -> thread::Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        self.handle.join()
    }
}

fn process_events(
    state: ProcessorState,
    shutdown: Arc<AtomicBool>,
    channel: Arc<EventChannelType>,
) {
    loop {
        // this blocks the thread until the events are ready or the server is shut down
        let mut message = state
            .executor
            .await_kv_events()
            .expect("Failed to await responses");

        let should_shutdown = message.shutdown || shutdown.load(Ordering::Relaxed);
        message.shutdown = should_shutdown;

        if let Err(e) = channel.send(message) {
            tracing::debug!("Failed to send message to channel: {:?}", e);
        }

        if should_shutdown {
            tracing::debug!("Shutting down KV Event Processor");
            break;
        }
    }
}
lib/engines/trtllm/src/executor/processors/response.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::thread;

use tokio::sync::mpsc;

use super::*;
use crate::executor::ResponseQueues;

pub struct ResponseProcessor {
    handle: thread::JoinHandle<()>,
}

impl ResponseProcessor {
    pub fn new(state: ProcessorState, response_queues: ResponseQueues) -> Self {
        let handle = std::thread::spawn(move || {
            process_responses(state, response_queues);
        });
        ResponseProcessor { handle }
    }

    /// Block and wait for the response processor to finish
    pub fn join(self) -> thread::Result<()> {
        self.handle.join()
    }
}

#[derive(Debug, thiserror::Error)]
enum ResponseError {
    #[error("Response queue dropped; possible client disconnect")]
    ResponseQueueDropped,

    #[error("Response channel closed; possible client disconnect")]
    ChannelClosed,

    #[error("Response channel full; backpressure detected in response stream")]
    ChannelFull,

    #[error("Invalid response: no error or result found")]
    InvalidResponse,

    /// Error indicating that TensorRT-LLM returned an error.
    /// This also indicates that the request was not successful and no further responses
    /// will be sent for this request.
    #[error("TensorRT LLM Engine Error: {0}")]
    EngineError(String),

    #[error("Completed successfully")]
    RequestComplete,
}

fn process_responses(state: ProcessorState, response_queues: ResponseQueues) {
    loop {
        // this blocks the thread until the response is ready or the server is shut down
        let message = state
            .executor
            .await_responses()
            .expect("Failed to await responses");

        // check shutdown condition
        if message.shutdown {
            tracing::info!("Server shutdown detected");
            break;
        }

        // process responses - hold the lock while we iterate to avoid the contention of
        // grabbing and releasing it for each response
        let mut queues = response_queues.lock().unwrap();
        for output in message.responses {
            let request_id = output.request_id;
            let client_id = output.client_id.expect("client_id is missing");
            let tx = queues.get(&client_id);

            match try_send(tx, output) {
                Ok(_) => {}
                Err(e) => {
                    tracing::trace!(client_id, "processing response: {}", e);
                    match e {
                        ResponseError::InvalidResponse => {
                            // this would likely be a bug on the server; we expect the oneof to be set
                            tracing::warn!(client_id, "Invalid response; No action required");
                        }
                        ResponseError::EngineError(_) => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelFull => {
                            // critical error
                            tracing::error!(client_id, "Alert: backpressure detected in response stream");
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ChannelClosed => {
                            // the first indication the client has disconnected
                            state.executor.cancel_request(request_id);
                            queues.remove(&client_id);
                        }
                        ResponseError::ResponseQueueDropped => {
                            // if we get a response for a dropped queue, we need to cancel the request
                            state.executor.cancel_request(request_id);
                        }
                        ResponseError::RequestComplete => {
                            // no need to cancel, the server will not send any more responses
                            queues.remove(&client_id);
                        }
                    }
                }
            }
        }
    }
}

fn try_send(
    tx: Option<&mpsc::Sender<Result<protocols::Output>>>,
    response: protocols::Response,
) -> Result<(), ResponseError> {
    let mut rc = Ok(());
    let tx = tx.ok_or(ResponseError::ResponseQueueDropped)?;

    let result = match (response.output, response.error_msg) {
        (Some(output), None) => {
            if output.is_final {
                rc = Err(ResponseError::RequestComplete);
            }
            Ok(output)
        }
        (None, Some(e)) => {
            rc = Err(ResponseError::EngineError(e.clone()));
            Err(ResponseError::EngineError(e.clone()))
        }
        (None, None) => return Err(ResponseError::InvalidResponse),
        (Some(_), Some(_)) => return Err(ResponseError::InvalidResponse),
    };

    match tx.try_send(result.map_err(|e| e.into())) {
        Ok(_) => {}
        Err(e) => match e {
            mpsc::error::TrySendError::Closed(_) => {
                return Err(ResponseError::ChannelClosed);
            }
            mpsc::error::TrySendError::Full(_) => {
                return Err(ResponseError::ChannelFull);
            }
        },
    }

    rc
}
lib/engines/trtllm/src/executor/protocols.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};

pub mod kv;
pub mod outputs;
pub mod stats;

pub use outputs::*;

#[derive(Serialize, Deserialize, Default)]
pub struct SamplingConfig {
    pub beam_width: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_min: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_reset_ids: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p_decay: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub beam_search_diversity_rate: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub length_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub early_stopping: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_repeat_ngram_size: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_return_sequences: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OutputConfig {
    pub return_log_probs: bool,
    pub return_context_logits: bool,
    pub return_generation_logits: bool,
    pub exclude_input_from_output: bool,
    pub return_encoder_output: bool,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RetentionPriorityAndDuration {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retention_priority: Option<u32>, // google.protobuf.UInt32Value

    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenRangeRetentionConfig {
    pub token_start: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_end: Option<u32>, // google.protobuf.UInt32Value

    pub priority: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>, // google.protobuf.UInt64Value
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KvCacheRetentionConfig {
    pub token_range_retention_configs: Vec<TokenRangeRetentionConfig>,
    pub decode_retention_priority: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub decode_duration_ms: Option<u64>, // google.protobuf.UInt64Value
}

#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct Request {
    pub input_token_ids: Vec<u32>,
    pub max_tokens: u32,
    pub streaming: bool,

    // pub sampling_config: SamplingConfig,
    // pub output_config: OutputConfig,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_id: Option<u32>,
    // pub pad_id: Option<u32>, // google.protobuf.UInt32Value
    // pub position_ids: Vec<u32>,
    // pub bad_words: Vec<u32>,
    // pub stop_words: Vec<u32>,
    // pub embedding_bias: Vec<u8>, // bytes
    // // TODO: Add external_draft_tokens_config: ExternalDraftTokensConfig
    // // TODO: Add prompt_tuning_config: PromptTuningConfig
    // // TODO: Add lora_config: LoraConfig
    // // TODO: Add lookahead_config: LookaheadDecodingConfig
    // pub kv_cache_retention_config: KvCacheRetentionConfig,
    // pub logits_post_processor_name: String,
    // pub encoder_input_token_ids: Vec<u32>,
    // pub client_id: Option<u64>, // google.protobuf.UInt64Value
    // pub return_all_generated_tokens: bool,
    // pub priority: f32,
    // pub request_type: u32,
    // // TODO: Add context_phase_params: ContextPhaseParams
    // pub encoder_input_features: Vec<u8>, // bytes
    // pub encoder_output_length: Option<u32>, // google.protobuf.UInt32Value
    // pub cross_attention_mask: Vec<u8>, // bytes
    // pub num_return_sequences: u32,
    // // TODO: Add eagle_config: EagleConfig
    // pub skip_cross_attn_blocks: Vec<u8>, // bytes
}

// todo - return a Result
impl Request {
    pub fn new(input_token_ids: Vec<u32>, max_tokens: u32) -> Self {
        RequestBuilder::default()
            .input_token_ids(input_token_ids)
            .max_tokens(max_tokens)
            .streaming(true)
            .build()
            .unwrap()
    }
}

// todo - convert to a TryFrom
impl From<dynamo_llm::protocols::common::llm_backend::BackendInput> for Request {
    fn from(input: dynamo_llm::protocols::common::llm_backend::BackendInput) -> Self {
        RequestBuilder::default()
            .input_token_ids(input.token_ids)
            .max_tokens(input.stop_conditions.max_tokens.unwrap_or(16))
            .streaming(true)
            .end_id(input.eos_token_ids.last().cloned())
            .build()
            .unwrap()
    }
}
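SamplingConfig and Request rely on serde's skip_serializing_if = "Option::is_none" so that unset optional fields are omitted from the wire format rather than serialized as null. A hypothetical, minimal sketch (serde + serde_json, with a cut-down SketchConfig standing in for SamplingConfig) of what that attribute buys:

// Sketch only: shows the JSON shape produced by skip_serializing_if.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Default)]
struct SketchConfig {
    beam_width: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

fn main() {
    let cfg = SketchConfig { beam_width: 1, top_k: Some(50), temperature: None };
    let json = serde_json::to_string(&cfg).unwrap();
    // temperature is omitted entirely because it is None
    assert_eq!(json, r#"{"beam_width":1,"top_k":50}"#);
}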
lib/engines/trtllm/src/executor/protocols/kv.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
lib/engines/trtllm/src/executor/protocols/outputs.rs
deleted
100644 → 0
View file @
d797b4ba
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use dynamo_llm::protocols::{
    common::{self},
    TokenIdType,
};

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Responses {
    pub responses: Vec<Response>,
    pub shutdown: bool,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Response {
    pub request_id: u64,
    pub client_id: Option<u64>,    // Optional client ID.
    pub error_msg: Option<String>, // Error message if the request failed.
    pub output: Option<Output>,    // Output if the request succeeded.
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Output {
    pub is_final: bool,
    pub token_ids: Vec<TokenIdType>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub cum_log_prob: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<Vec<f64>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReasonEnum>,
}

#[derive(Serialize_repr, Deserialize_repr, Debug, Clone)]
#[repr(u8)]
pub enum FinishReasonEnum {
    FinishReasonNotDone = 0,
    FinishReasonEos = 1,
    FinishReasonStop = 2,
    FinishReasonLength = 3,
}

impl From<Output> for common::llm_backend::LLMEngineOutput {
    fn from(output: Output) -> Self {
        let finish_reason = match output.finish_reason {
            Some(FinishReasonEnum::FinishReasonNotDone) => None,
            Some(FinishReasonEnum::FinishReasonEos) => Some(common::FinishReason::EoS),
            Some(FinishReasonEnum::FinishReasonStop) => Some(common::FinishReason::Stop),
            Some(FinishReasonEnum::FinishReasonLength) => Some(common::FinishReason::Length),
            None => None,
        };

        common::llm_backend::LLMEngineOutput {
            // todo - propagate mdcsum
            token_ids: output.token_ids,
            tokens: None,
            text: None,
            cum_log_probs: output.cum_log_prob,
            log_probs: None,
            finish_reason,
        }
    }
}
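FinishReasonEnum uses serde_repr with repr(u8), so the finish reason crosses the serialization boundary as its integer discriminant rather than as a string. A hypothetical, minimal round-trip sketch (serde_repr + serde_json, with a trimmed FinishReason enum standing in for the one above):

// Sketch only: integer-valued serialization of a repr(u8) enum.
use serde_repr::{Deserialize_repr, Serialize_repr};

#[derive(Serialize_repr, Deserialize_repr, Debug, Clone, PartialEq)]
#[repr(u8)]
enum FinishReason {
    NotDone = 0,
    Eos = 1,
    Stop = 2,
    Length = 3,
}

fn main() {
    // serializes as the discriminant, not the variant name
    let json = serde_json::to_string(&FinishReason::Length).unwrap();
    assert_eq!(json, "3");

    // and deserializes back from the integer
    let back: FinishReason = serde_json::from_str("1").unwrap();
    assert_eq!(back, FinishReason::Eos);
}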