"lib/bindings/vscode:/vscode.git/clone" did not exist on "e159e53fe6127330e62efe855b3c310da9e4183b"
Commit 057f8f47 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: TensorRT-LLM engine (#317)

Engine, `tio` support and docs.

Proof of concept / experimental.
parent 11a36651
...@@ -447,6 +447,26 @@ version = "1.6.0" ...@@ -447,6 +447,26 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
...@@ -542,6 +562,15 @@ dependencies = [ ...@@ -542,6 +562,15 @@ dependencies = [
"shlex", "shlex",
] ]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-expr" name = "cfg-expr"
version = "0.15.8" version = "0.15.8"
...@@ -585,6 +614,17 @@ dependencies = [ ...@@ -585,6 +614,17 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.31" version = "4.5.31"
...@@ -625,6 +665,15 @@ version = "0.7.4" ...@@ -625,6 +665,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.3"
...@@ -1352,6 +1401,12 @@ version = "0.31.1" ...@@ -1352,6 +1401,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.26" version = "0.3.26"
...@@ -1940,6 +1995,16 @@ version = "0.2.169" ...@@ -1940,6 +1995,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
...@@ -2854,7 +2919,7 @@ dependencies = [ ...@@ -2854,7 +2919,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"quinn-proto", "quinn-proto",
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2", "socket2",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -2872,7 +2937,7 @@ dependencies = [ ...@@ -2872,7 +2937,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand",
"ring", "ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"rustls-pki-types", "rustls-pki-types",
"slab", "slab",
...@@ -3152,6 +3217,12 @@ version = "0.1.24" ...@@ -3152,6 +3217,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
...@@ -4170,10 +4241,12 @@ dependencies = [ ...@@ -4170,10 +4241,12 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "axum 0.8.1",
"bindgen",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -4191,6 +4264,7 @@ dependencies = [ ...@@ -4191,6 +4264,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
...@@ -392,6 +392,26 @@ version = "1.6.0" ...@@ -392,6 +392,26 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
...@@ -493,6 +513,15 @@ dependencies = [ ...@@ -493,6 +513,15 @@ dependencies = [
"shlex", "shlex",
] ]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-expr" name = "cfg-expr"
version = "0.15.8" version = "0.15.8"
...@@ -536,6 +565,17 @@ dependencies = [ ...@@ -536,6 +565,17 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.30" version = "4.5.30"
...@@ -576,6 +616,15 @@ version = "0.7.4" ...@@ -576,6 +616,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.3"
...@@ -1274,6 +1323,12 @@ version = "0.31.1" ...@@ -1274,6 +1323,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.8" version = "0.4.8"
...@@ -1803,6 +1858,16 @@ version = "0.2.169" ...@@ -1803,6 +1858,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
...@@ -2674,7 +2739,7 @@ dependencies = [ ...@@ -2674,7 +2739,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"quinn-proto", "quinn-proto",
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2", "socket2",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -2692,7 +2757,7 @@ dependencies = [ ...@@ -2692,7 +2757,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand",
"ring", "ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"rustls-pki-types", "rustls-pki-types",
"slab", "slab",
...@@ -2932,6 +2997,12 @@ version = "0.1.24" ...@@ -2932,6 +2997,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
...@@ -3938,10 +4009,12 @@ dependencies = [ ...@@ -3938,10 +4009,12 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "axum 0.8.1",
"bindgen",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -3959,6 +4032,7 @@ dependencies = [ ...@@ -3959,6 +4032,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
...@@ -452,7 +452,7 @@ dependencies = [ ...@@ -452,7 +452,7 @@ dependencies = [
"bitflags 2.8.0", "bitflags 2.8.0",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"itertools 0.11.0", "itertools 0.12.1",
"lazy_static", "lazy_static",
"lazycell", "lazycell",
"log", "log",
...@@ -466,6 +466,26 @@ dependencies = [ ...@@ -466,6 +466,26 @@ dependencies = [
"which", "which",
] ]
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
]
[[package]] [[package]]
name = "bindgen_cuda" name = "bindgen_cuda"
version = "0.1.5" version = "0.1.5"
...@@ -2674,7 +2694,7 @@ version = "0.1.102" ...@@ -2674,7 +2694,7 @@ version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e" checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [ dependencies = [
"bindgen", "bindgen 0.69.5",
"cc", "cc",
"cmake", "cmake",
"find_cuda_helper", "find_cuda_helper",
...@@ -5584,10 +5604,12 @@ dependencies = [ ...@@ -5584,10 +5604,12 @@ dependencies = [
"async-trait", "async-trait",
"async_zmq", "async_zmq",
"axum 0.8.1", "axum 0.8.1",
"bindgen 0.70.1",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -5607,6 +5629,7 @@ dependencies = [ ...@@ -5607,6 +5629,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum 0.27.1", "strum 0.27.1",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
...@@ -25,6 +25,7 @@ license = "Apache-2.0" ...@@ -25,6 +25,7 @@ license = "Apache-2.0"
mistralrs = ["triton-distributed-llm/mistralrs"] mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"] sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
llamacpp = ["triton-distributed-llm/llamacpp"] llamacpp = ["triton-distributed-llm/llamacpp"]
trtllm = ["triton-distributed-llm/trtllm"]
cuda = ["triton-distributed-llm/cuda"] cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"] metal = ["triton-distributed-llm/metal"]
vllm = ["triton-distributed-llm/vllm"] vllm = ["triton-distributed-llm/vllm"]
......
...@@ -139,3 +139,41 @@ Run (still inside that virtualenv) - GGUF: ...@@ -139,3 +139,41 @@ Run (still inside that virtualenv) - GGUF:
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/ ./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/
``` ```
## trtllm
TensorRT-LLM. Requires `clang` and `libclang-dev`.
Build:
```
cargo build --release --features trtllm
```
Run:
```
tio in=text out=trtllm --model-path /app/trtllm_engine/ --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```
Note that TRT-LLM uses its own `.engine` format for weights. Repo models must be converted like so:
+ Get the build container
```
docker run --gpus all -it nvcr.io/nvidian/nemo-llm/trtllm-engine-builder:0.2.0 bash
```
+ Fetch the model and convert
```
mkdir /tmp/model
huggingface-cli download meta-llama/Llama-3.2-3B-Instruct --local-dir /tmp/model
python convert_checkpoint.py --model_dir /tmp/model/ --output_dir ./converted --dtype [float16|bfloat16|whatever you want] --tp_size X --pp_size Y
trtllm-build --checkpoint_dir ./converted --output_dir ./final/trtllm_engine --use_paged_context_fmha enable --gemm_plugin auto
```
The `--model-path` you give to `tio` must contain the `config.json` (TRT-LLM's, not the model's) and `rank0.engine` (plus other ranks if relevant).
+ Execute
TRT-LLM is a C++ library that must have been previously built and installed. It needs a lot of memory to compile. Gitlab builds a container you can try:
```
sudo docker run --gpus all -it -v /home/graham:/outside-home gitlab-master.nvidia.com:5005/dl/ai-services/libraries/rust/nim-nvllm/tensorrt_llm_runtime:85fa4a6f
```
Copy the trt-llm engine, the model's `.json` files (for the model deployment card) and the `tio` binary built for the correct glibc (container is Ubuntu 22.04 currently) into that container.
...@@ -71,7 +71,7 @@ pub struct Flags { ...@@ -71,7 +71,7 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub model_config: Option<PathBuf>, pub model_config: Option<PathBuf>,
/// sglang only /// sglang and trtllm only
/// ///
/// How many GPUs to use at once, total across all nodes. /// How many GPUs to use at once, total across all nodes.
/// This must divide by num_nodes, and each node must use the same number of GPUs. /// This must divide by num_nodes, and each node must use the same number of GPUs.
...@@ -377,6 +377,26 @@ pub async fn run( ...@@ -377,6 +377,26 @@ pub async fn run(
card: Box::new(card), card: Box::new(card),
} }
} }
#[cfg(feature = "trtllm")]
Output::TrtLLM => {
use triton_distributed_llm::engines::trtllm;
let Some(model_path) = model_path else {
anyhow::bail!("out=trtllm requires flag --model-path=<full-path-to-model-dir>");
};
if !model_path.is_dir() {
anyhow::bail!(
"--model-path should point at a directory containing `.engine` files."
);
}
// Safety: Earlier we build maybe_card from model_path, which we checked right above
let card = maybe_card.clone().unwrap();
let engine = trtllm::make_engine(model_path.display(), flags.tensor_parallel_size)?;
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
}; };
match in_opt { match in_opt {
......
...@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull; ...@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio"; const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]"; const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
logging::init(); logging::init();
......
...@@ -87,6 +87,10 @@ pub enum Output { ...@@ -87,6 +87,10 @@ pub enum Output {
#[cfg(feature = "vllm")] #[cfg(feature = "vllm")]
/// Run inference using vllm's engine /// Run inference using vllm's engine
Vllm, Vllm,
#[cfg(feature = "trtllm")]
/// Run inference using trtllm
TrtLLM,
} }
impl TryFrom<&str> for Output { impl TryFrom<&str> for Output {
...@@ -106,6 +110,9 @@ impl TryFrom<&str> for Output { ...@@ -106,6 +110,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "vllm")] #[cfg(feature = "vllm")]
"vllm" => Ok(Output::Vllm), "vllm" => Ok(Output::Vllm),
#[cfg(feature = "trtllm")]
"trtllm" => Ok(Output::TrtLLM),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore), "echo_core" => Ok(Output::EchoCore),
...@@ -134,6 +141,9 @@ impl fmt::Display for Output { ...@@ -134,6 +141,9 @@ impl fmt::Display for Output {
#[cfg(feature = "vllm")] #[cfg(feature = "vllm")]
Output::Vllm => "vllm", Output::Vllm => "vllm",
#[cfg(feature = "trtllm")]
Output::TrtLLM => "trtllm",
Output::EchoFull => "echo_full", Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core", Output::EchoCore => "echo_core",
......
...@@ -392,6 +392,26 @@ version = "1.6.0" ...@@ -392,6 +392,26 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.7.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.96",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
...@@ -505,6 +525,15 @@ dependencies = [ ...@@ -505,6 +525,15 @@ dependencies = [
"shlex", "shlex",
] ]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-expr" name = "cfg-expr"
version = "0.15.8" version = "0.15.8"
...@@ -548,6 +577,17 @@ dependencies = [ ...@@ -548,6 +577,17 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.26" version = "4.5.26"
...@@ -575,6 +615,15 @@ version = "0.7.4" ...@@ -575,6 +615,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.3" version = "1.0.3"
...@@ -1273,6 +1322,12 @@ version = "0.31.1" ...@@ -1273,6 +1322,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.7" version = "0.4.7"
...@@ -1798,6 +1853,16 @@ version = "0.2.169" ...@@ -1798,6 +1853,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
...@@ -2656,7 +2721,7 @@ dependencies = [ ...@@ -2656,7 +2721,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"quinn-proto", "quinn-proto",
"quinn-udp", "quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"socket2", "socket2",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -2674,7 +2739,7 @@ dependencies = [ ...@@ -2674,7 +2739,7 @@ dependencies = [
"getrandom 0.2.15", "getrandom 0.2.15",
"rand", "rand",
"ring", "ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls", "rustls",
"rustls-pki-types", "rustls-pki-types",
"slab", "slab",
...@@ -2915,6 +2980,12 @@ version = "0.1.24" ...@@ -2915,6 +2980,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
...@@ -3893,10 +3964,12 @@ dependencies = [ ...@@ -3893,10 +3964,12 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.8.1", "axum 0.8.1",
"bindgen",
"blake3", "blake3",
"bs62", "bs62",
"bytes", "bytes",
"chrono", "chrono",
"cmake",
"derive_builder", "derive_builder",
"either", "either",
"erased-serde", "erased-serde",
...@@ -3914,6 +3987,7 @@ dependencies = [ ...@@ -3914,6 +3987,7 @@ dependencies = [
"serde", "serde",
"serde-pickle", "serde-pickle",
"serde_json", "serde_json",
"serde_repr",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
---
# Refer to the following link for the explanation of each params:
# http://releases.llvm.org/12.0.0/tools/clang/docs/ClangFormatStyleOptions.html
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Left
AlignOperands: true
AlignTrailingComments: true
AllowAllArgumentsOnNextLine: false
AllowAllConstructorInitializersOnNextLine: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: false # Allows placing breakpoint
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false
# This is deprecated
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: true
AfterUnion: true
AfterExternBlock: false
BeforeCatch: false
BeforeElse: true
IndentBraces: false
# disabling the below splits, else, they'll just add to the vertical length of source files!
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: AfterColon
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform
ConstructorInitializerIndentWidth: 2
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:
- foreach
- Q_FOREACH
- BOOST_FOREACH
IncludeBlocks: Regroup
IncludeCategories:
# The order of the groups is
# 0 - Main include file for .cpp
# 1 - source relative files `#include "./some_header.hpp"` (Grouped with 2)
# 2 - source relative files starting with internal/public `#include "internal/some_header.hpp"`
# 3 - Python MRC public API files `#include "pymrc/mrc_header.hpp"`
# 4 - MRC public API files `#include "mrc/mrc_header.hpp"`
# 5 - NVRPC public API files `#include "nvrpc/some_header.hpp"`
# 6 - External installed libraries `#include <external_lib/some_header.hpp>`
# 7 - System includes `#include <string>`
# First match any Python MRC public API headers with quotes
- Regex: '^"pymrc\/.*\.(h|hpp)"'
Priority: 3
# Next match any MRC public API headers with quotes
- Regex: '^"mrc\/.*\.(h|hpp)"'
Priority: 4
# Next match public NVRPC headers with quotes
- Regex: '^<nvrpc\/.*\.(h|hpp)>'
Priority: 5
# Next find any headers in internal or public
- Regex: '^"(internal|public)\/.*\.(h|hpp)"'
Priority: 2
    # Any other quoted includes need to be with internal/public but on top (That's why this group is last)
- Regex: '^".*\.(h|hpp)"'
Priority: 1
    # Last is system includes which don't have a '/' like <string> or <mutex>
- Regex: '<([a-z_])+>'
Priority: 7
# Finally, put all 3rd party includes before the system includes
- Regex: '^<.*'
Priority: 6
# IncludeIsMainSourceRegex: '$?'
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseBlocks: false
IndentCaseLabels: false
IndentPPDirectives: BeforeHash
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Never
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PackConstructorInitializers: CurrentLine
PenaltyBreakAssignment: 80
PenaltyBreakBeforeFirstCallParameter: 0
PenaltyBreakComment: 10
PenaltyBreakFirstLessLess: 10
PenaltyBreakString: 0
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 100
PenaltyReturnTypeOnItsOwnLine: 600
PointerAlignment: Left
RawStringFormats:
- Language: Cpp
Delimiters:
- cc
- CC
- cpp
- Cpp
- CPP
- 'c++'
- 'C++'
CanonicalDelimiter: ''
- Language: TextProto
Delimiters:
- pb
- PB
- proto
- PROTO
EnclosingFunctions:
- EqualsProto
- EquivToProto
- PARSE_PARTIAL_TEXT_PROTO
- PARSE_TEST_PROTO
- PARSE_TEXT_PROTO
- ParseTextOrDie
- ParseTextProtoOrDie
CanonicalDelimiter: ''
BasedOnStyle: google
# Enabling comment reflow causes doxygen comments to be messed up in their formats!
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: c++20
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
# Be consistent with indent-width, even for people who use tab for indentation!
TabWidth: 4
UseTab: Never
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
cmake_minimum_required(VERSION 3.17)

project(
    nvllm
    VERSION 0.1.0.0
    LANGUAGES CXX
)

include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/set_ifndef.cmake)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
# BUGFIX: was misspelled `CMAKE_POSTION_INDEPENDENT_CODE`, which set an unused
# variable and silently never enabled -fPIC for the targets below.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

option(USE_STUBS "Build with stub implementations instead of real CUDA code" OFF)

if(USE_STUBS)
    # Stub build: no CUDA / TensorRT-LLM install required (useful for CI and
    # machines without GPUs). A fake libtensorrt_llm is built from the stub.
    add_definitions(-DUSE_STUBS)
    set(SOURCE_FILES
        src/nvllm_trt.cpp
        src/engine_stub/engine.cpp
    )
    add_library(tensorrt_llm SHARED src/engine_stub/tensorrt_llm.cpp)
else()
    # Real build: links against a prebuilt TensorRT-LLM installation.
    #set(TRTLLM_SRC_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../submodules/tensorrt_llm" CACHE STRING "TRTLLM_SRC_DIR: /../../submodules/tensorrt_llm")
    set(TRTLLM_LIB_DIR "/usr/local/lib" CACHE STRING "TRTLLM_LIB_DIR: /usr/local/lib")
    #include(${TRTLLM_SRC_DIR}/cpp/cmake/modules/find_library_create_target.cmake)

    # NOTE(review): the first separator is `;` (a CMake list element boundary);
    # the rest are `:` inside a single rpath entry. That still resolves at
    # runtime on Linux, but confirm the mix is intentional before normalizing.
    set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib;/opt/hpcx/ompi/lib:/usr/local/cuda/lib64:/usr/local/tensorrt/targets/x86_64-linux-gnu/lib:/src/tensorrt_llm/cpp/build/tensorrt_llm/plugins")
    set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)

    include(FetchContent)
    FetchContent_Declare(
        json
        GIT_REPOSITORY https://github.com/nlohmann/json.git
        GIT_TAG v3.11.2
    )
    FetchContent_Declare(
        spdlog
        GIT_REPOSITORY https://github.com/gabime/spdlog.git
        GIT_TAG v1.15.0
    )
    # Make nlohmann/json and spdlog available to downstream targets.
    FetchContent_MakeAvailable(json)
    FetchContent_MakeAvailable(spdlog)
    # spdlog is a static lib here; it must be PIC to link into our shared lib.
    set_property(TARGET spdlog PROPERTY POSITION_INDEPENDENT_CODE ON)

    # Prebuilt TensorRT-LLM shared libraries — imported, not built by us.
    add_library(tensorrt_llm SHARED IMPORTED)
    set_target_properties(
        tensorrt_llm
        PROPERTIES
            IMPORTED_LOCATION "${TRTLLM_LIB_DIR}/libtensorrt_llm.so"
    )
    add_library(nvinfer_plugin_tensorrt_llm SHARED IMPORTED)
    set_target_properties(
        nvinfer_plugin_tensorrt_llm
        PROPERTIES
            IMPORTED_LOCATION "${TRTLLM_LIB_DIR}/libnvinfer_plugin_tensorrt_llm.so"
    )
    add_library(xxhash STATIC IMPORTED)
    set_target_properties(
        xxhash
        PROPERTIES
            IMPORTED_LOCATION "/usr/lib/x86_64-linux-gnu/libxxhash.a"
    )

    set(SOURCE_FILES
        src/nvllm_trt.cpp
        src/engine_trt/engine.cpp
        src/engine_trt/request.cpp
        src/engine_trt/response.cpp
        src/engine_trt/config.cpp
        src/engine_trt/kv_event.cpp
        src/engine_trt/stats.cpp
        ${PROTO_SRCS} ${PROTO_HDRS}
        # ... other source files ...
    )
endif()
# Applies include paths, compile options and link dependencies shared by the
# nvllm_trt library target. Expects USE_STUBS / imported targets from above.
function(set_library_target_properties target)
    target_include_directories(
        ${target}
        PUBLIC
        $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/>
        $<INSTALL_INTERFACE:include/>
        PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src
                ${CMAKE_BINARY_DIR}
                /usr/local/cuda-12.6/targets/x86_64-linux/include
                /usr/local/tensorrt/include/
    )
    target_compile_features(${target} PRIVATE cxx_std_17)
    set_target_properties(${target} PROPERTIES OUTPUT_NAME nvllm_trt)
    # Fix: -Wno-deprecated-declarations previously sat OUTSIDE the GNU/Clang
    # generator expression, so it was passed unconditionally — including to
    # MSVC, which does not understand it. It now lives with the other
    # GCC/Clang warning flags.
    target_compile_options(
        ${target}
        PRIVATE
        $<$<OR:$<CXX_COMPILER_ID:Clang>,$<CXX_COMPILER_ID:AppleClang>,$<CXX_COMPILER_ID:GNU>>:
        -Wall
        -Wextra
        -Wno-unused-parameter
        -Wno-type-limits
        -Wno-deprecated-declarations>
        $<$<CXX_COMPILER_ID:MSVC>:/Wall
        /D_WIN32_WINNT=0x0A00
        /EHsc>)
    # Stub builds need no extra link dependencies; only the real build links
    # the imported TensorRT LLM / json / spdlog / xxhash targets.
    if (NOT USE_STUBS)
        target_link_libraries(
            ${target}
            PRIVATE tensorrt_llm
                    ${Protobuf_LIBRARIES}
                    xxhash
                    # ${MPI_LIBRARIES}
                    # ${CUDA_LIBRARIES}
                    # nvinfer
                    nvinfer_plugin_tensorrt_llm
                    nlohmann_json::nlohmann_json
                    spdlog::spdlog
        )
    endif()
    # target_link_options(${target} PRIVATE "-static")
endfunction()
# The single output target; SOURCE_FILES is set by the USE_STUBS branch above.
add_library(nvllm_trt SHARED ${SOURCE_FILES})
set_library_target_properties(nvllm_trt)
include(CMakePackageConfigHelpers)
configure_package_config_file(
    ${CMAKE_CURRENT_SOURCE_DIR}/nvllmConfig.cmake.in
    ${CMAKE_CURRENT_BINARY_DIR}/nvllmConfig.cmake
    INSTALL_DESTINATION lib/cmake/nvllm
)
write_basic_package_version_file(
    "nvllmConfigVersion.cmake"
    VERSION ${PROJECT_VERSION}
    COMPATIBILITY AnyNewerVersion
)
# Installation rules
# NOTE(review): the EXPORT below is declared but there is no matching
# install(EXPORT nvllmConfig ...) visible — confirm consumers can actually
# import the target, or that nvllmConfig.cmake.in includes the export file.
install(TARGETS nvllm_trt
        EXPORT nvllmConfig # This should match the name used in configure_package_config_file
        LIBRARY DESTINATION lib
        ARCHIVE DESTINATION lib
        RUNTIME DESTINATION bin
        INCLUDES DESTINATION include
)
# Install the nvllmConfig.cmake and nvllmConfigVersion.cmake files
install(FILES
    ${CMAKE_CURRENT_BINARY_DIR}/nvllmConfig.cmake # Corrected the file name
    ${CMAKE_CURRENT_BINARY_DIR}/nvllmConfigVersion.cmake
    DESTINATION lib/cmake/nvllm
)
# # Install config.h
# install(FILES "${PROJECT_BINARY_DIR}/config.h"
#     DESTINATION include/nvidia/nvllm)
# Install header files
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
        DESTINATION include)
/* C API for the nvLLM TensorRT-LLM streaming engine.
 * All functions returning char* hand ownership of the buffer to the caller;
 * release it with nvllm_trt_engine_free_responses.
 * NOTE(review): identifiers beginning with a double underscore are reserved
 * in C/C++ — consider renaming this include guard. */
#ifndef __NVIDIA_NVLLM_TRT_C_API__
#define __NVIDIA_NVLLM_TRT_C_API__
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
/* Error codes returned by the engine C API. */
typedef enum
{
    NVLLM_TRT_ENGINE_SUCCESS = 0,  // No error
    NVLLM_TRT_ENGINE_INVALID_REQUEST = 1,  // Invalid request error
    NVLLM_TRT_ENGINE_SHUTDOWN_REQUIRED = 2,  // Shutdown and join required before destroying
    NVLLM_TRT_ENGINE_SHUTDOWN_IN_PROGRESS = 3,  // Shutdown in progress
} nvllm_trt_engine_error_t;
// struct nvllm_trt_engine {};
// Forward declaration of the C++ class
typedef struct nvllm_trt_engine nvllm_trt_engine;
typedef nvllm_trt_engine* nvllm_trt_engine_t;
typedef uint64_t request_id_t;
typedef uint64_t client_id_t;
// Set the MPI Communicator for the TensorRT LLM Engine
// This function should be called before creating the engine
int nvllm_trt_mpi_session_set_communicator(void* world_comm_ptr);
// Create an engine from a serialized configuration.
// NOTE(review): despite the parameter name, the C++ side parses this as JSON —
// confirm whether the payload is proto or JSON and align the naming.
nvllm_trt_engine_t nvllm_trt_engine_create(const char* config_proto);
// Create a nvLLM TRT Engine from an instance of the engine
// This requires the raw engine pointer to be an instantiated object at the exact same
// commit version as the version of TRTLLM used to build the nvLLM C API.
// This is a workaround to enable the Triton TensorRT LLM backend to use nvLLM.
nvllm_trt_engine_t nvllm_trt_engine_unsafe_create_from_executor(void* engine);
// Source: Enqueue a streaming request via a json message to the request queue
request_id_t nvllm_trt_engine_enqueue_request(nvllm_trt_engine_t engine, client_id_t client_id, const char* req_proto);
// Sink: Pull off streaming responses from the response queue
char* nvllm_trt_engine_await_responses(nvllm_trt_engine_t engine);
// Sink: Pull off KvEvents from the event queue
char* nvllm_trt_engine_await_kv_events(nvllm_trt_engine_t engine);
// Get basic iteration stats
char* nvllm_trt_engine_await_iter_stats(nvllm_trt_engine_t engine);
// Free the memory allocated by nvllm_trt_engine_await_responses
void nvllm_trt_engine_free_responses(char* responses);
// Cancel an in-flight request, identified by the engine-assigned request_id
// (not the caller-supplied client_id).
void nvllm_trt_engine_cancel_request(nvllm_trt_engine_t engine, uint64_t request_id);
// Initiate the shutdown sequence
void nvllm_trt_engine_shutdown(nvllm_trt_engine_t engine);
// // Await for the shutdown to complete; shutdown will be requested if not already requested
// void nvllm_trt_engine_join(nvllm_trt_engine_t engine);
// Destroy the engine
int nvllm_trt_engine_destroy(nvllm_trt_engine_t engine);
// Returns true (non-zero) once the engine has started pulling requests
// There is currently no stopping, so once an engine has started,
// it will always return true, even when complete.
// This call does not block; the user should use some backoff strategy
// to poll for detecting the start of the engine.
int nvllm_trt_engine_is_ready(nvllm_trt_engine_t engine);
// Returns true (non-zero) once the engine has stopped pulling requests
int nvllm_trt_engine_has_completed(nvllm_trt_engine_t engine);
// // Returns the major version number of the trtllm library
// int trtllm_version_major();
// // Returns the minor version number of the trtllm library
// int trtllm_version_minor();
// // Returns the patch version number of the trtllm library
// int trtllm_version_patch();
#ifdef __cplusplus
}
#endif
#endif  // __NVIDIA_NVLLM_TRT_C_API__
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
namespace nvidia::nvllm::trt {
/// Pimpl facade over a TensorRT LLM executor. One instance drives one engine;
/// requests go in as serialized JSON strings and results come back as JSON.
class StreamingEngine
{
public:
    /// Builds the engine from a serialized configuration string.
    StreamingEngine(const std::string& config_proto);
    /// Wraps an already-constructed executor; see the C API header for the
    /// version-matching caveats of this path.
    StreamingEngine(void* engine);
    /// Triggers shutdown if the engine has not already completed.
    ~StreamingEngine();
    // accepts a string of a serialized proto::Request
    // forms the internal request object and enqueues it
    // returns a request_id provided by the engine; this must be used to cancel the request
    // accepts a client_id which can be use to identify the response
    uint64_t enqueue_request(uint64_t client_id, const std::string& json_request);
    // awaits the presence of a response (blocking)
    // converts the internal format to a json and returns the string
    std::string await_responses();
    // awaits the presence of kv events; nullopt when no event manager is available
    std::optional<std::string> await_kv_events();
    // Awaits iteration stats
    std::optional<std::string> await_iter_stats();
    // cancel request by engine-assigned request_id (not client_id)
    void cancel_request(uint64_t request_id);
    // called to start the shutdown sequence
    void shutdown();
    // returns true once the engine as started pulling requests
    // there is currently no stopping, so once an engine has_started,
    // it will always return true, even when complete
    bool is_ready() const;
    // returns true if the StreamingEngine has been both shutdown and joined
    bool has_completed() const;
private:
    class Impl;
    std::unique_ptr<Impl> m_impl;
};
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Public API for the StreamingEngine class
#include "nvidia/nvllm/nvllm_trt.h"
// Internal Private Implementation
#include "api/engine.hpp"
#include <optional>
extern "C" {
bool initTrtLlmPlugins(void* logger, char const* libNamespace);
}
namespace nvidia::nvllm::trt {
/// Stub implementation used by USE_STUBS builds: every operational method
/// calls std::abort(), so the stub library links and loads but terminates the
/// process if actually used. cancel_request is the one intentional no-op.
/// NOTE(review): this TU relies on transitive includes for <cstdint> and
/// <cstdlib> (uint64_t / std::abort) — confirm or include them explicitly.
class StreamingEngine::Impl
{
public:
    Impl(const std::string& config_proto);
    Impl(void* engine);
    ~Impl() = default;
    // Aborts; the `return 911` after std::abort() is unreachable and exists
    // only to satisfy the signature.
    uint64_t enqueue_request(uint64_t client_id, const std::string& req_proto)
    {
        std::abort();
        return 911;
    }
    // Intentional no-op in the stub.
    void cancel_request(uint64_t request_id) {}
    std::string await_responses()
    {
        std::abort();
        return {};
    }
    std::optional<std::string> await_kv_events()
    {
        std::abort();
        return std::nullopt;
    }
    std::optional<std::string> await_iter_stats()
    {
        std::abort();
        return std::nullopt;
    }
    void shutdown()
    {
        std::abort();
    }
    bool is_ready() const
    {
        std::abort();
        return false;
    }
    bool has_completed() const
    {
        std::abort();
        return false;
    }
};
// Private Engine Impl
StreamingEngine::Impl::Impl(const std::string& config_proto)
{
initTrtLlmPlugins(nullptr, nullptr);
}
StreamingEngine::Impl::Impl(void* engine)
{
initTrtLlmPlugins(nullptr, nullptr);
}
// Public Engine Impl
StreamingEngine::StreamingEngine(const std::string& config_proto) :
m_impl{std::make_unique<Impl>(config_proto)} {} // namespace nvidia::nvllm::trt
StreamingEngine::StreamingEngine(void* engine) :
m_impl{std::make_unique<Impl>(engine)} {} // namespace nvidia::nvllm::trt
StreamingEngine::~StreamingEngine()
{
if (!m_impl->has_completed())
{
m_impl->shutdown();
}
}
uint64_t StreamingEngine::enqueue_request(uint64_t client_id, const std::string& req_proto)
{
return m_impl->enqueue_request(client_id, req_proto);
}
std::string StreamingEngine::await_responses()
{
return m_impl->await_responses();
}
std::optional<std::string> StreamingEngine::await_kv_events()
{
return m_impl->await_kv_events();
}
std::optional<std::string> StreamingEngine::await_iter_stats()
{
return m_impl->await_iter_stats();
}
void StreamingEngine::cancel_request(uint64_t request_id)
{
m_impl->cancel_request(request_id);
}
void StreamingEngine::shutdown()
{
m_impl->shutdown();
}
bool StreamingEngine::is_ready() const
{
return m_impl->is_ready();
}
bool StreamingEngine::has_completed() const
{
return m_impl->has_completed();
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern "C" {
// Stub for the TensorRT LLM plugin initializer.
// Fix: the original body was empty, so a non-void function fell off the end —
// undefined behavior in C++. Return true to mirror a successful init.
bool initTrtLlmPlugins(void* logger, char const* libNamespace)
{
    (void)logger;        // unused in the stub
    (void)libNamespace;  // unused in the stub
    return true;
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/config.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
/// JSON-facing view of the engine configuration. Required fields are plain
/// strings; optional fields stay nullopt when absent from the input JSON and
/// are then omitted from serialized output.
struct ExecutorConfig
{
    std::string model_path;  // path to the compiled TRT-LLM engine directory
    std::string log_level;   // "error" | "warn" | "info" | "debug" | "trace"
    std::optional<bool> enable_chunked_context;
    std::optional<bool> normalize_log_probs;
    std::optional<uint32_t> iter_stats_max_iterations;
};
// Custom to_json function
inline void to_json(json& j, const ExecutorConfig& e)
{
j = json{{"model_path", e.model_path}, {"log_level", e.log_level}};
if (e.enable_chunked_context)
{
j["enable_chunked_context"] = e.enable_chunked_context.value();
}
if (e.normalize_log_probs)
{
j["normalize_log_probs"] = e.normalize_log_probs.value();
}
if (e.iter_stats_max_iterations)
{
j["iter_stats_max_iterations"] = e.iter_stats_max_iterations.value();
}
}
// Custom from_json function
inline void from_json(const json& j, ExecutorConfig& e)
{
j.at("model_path").get_to(e.model_path);
j.at("log_level").get_to(e.log_level);
if (j.contains("enable_chunked_context"))
{
e.enable_chunked_context = j.at("enable_chunked_context").get<bool>();
}
else
{
e.enable_chunked_context = std::nullopt;
}
if (j.contains("normalize_log_probs"))
{
e.normalize_log_probs = j.at("normalize_log_probs").get<bool>();
}
else
{
e.normalize_log_probs = std::nullopt;
}
if (j.contains("iter_stats_max_iterations"))
{
e.iter_stats_max_iterations = j.at("iter_stats_max_iterations").get<uint32_t>();
}
else
{
e.iter_stats_max_iterations = std::nullopt;
}
}
/// Parses a JSON configuration string into a Config holding the model path,
/// log level and a populated TRT-LLM ExecutorConfig.
/// Fix: normalize_log_probs and iter_stats_max_iterations were parsed from
/// the JSON but never applied to the executor config — they were silently
/// dropped. They are now forwarded like enable_chunked_context.
Config deserialize_config(const std::string& config_json)
{
    auto config_in = json::parse(config_json).get<ExecutorConfig>();
    auto model_path = config_in.model_path;
    auto log_level = config_in.log_level;
    auto config = ex::ExecutorConfig();
    // todo - expose max num tokens
    // todo - expose from engine block reuse
    if (config_in.enable_chunked_context)
    {
        spdlog::info("Enable chunked context: {}", config_in.enable_chunked_context.value() ? "true" : "false");
        config.setEnableChunkedContext(config_in.enable_chunked_context.value());
    }
    if (config_in.normalize_log_probs)
    {
        spdlog::info("Normalize log probs: {}", config_in.normalize_log_probs.value() ? "true" : "false");
        config.setNormalizeLogProbs(config_in.normalize_log_probs.value());
    }
    if (config_in.iter_stats_max_iterations)
    {
        spdlog::info("Iter stats max iterations: {}", config_in.iter_stats_max_iterations.value());
        config.setIterStatsMaxIterations(config_in.iter_stats_max_iterations.value());
    }
    return {model_path, log_level, config};
}
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensorrt_llm/executor/executor.h"
namespace nvidia::nvllm::trt {
struct Config
{
std::string model_path;
std::string log_level;
tensorrt_llm::executor::ExecutorConfig config;
};
Config deserialize_config(const std::string& request);
} // namespace nvidia::nvllm::trt
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Public API for the StreamingEngine class
#include "nvidia/nvllm/nvllm_trt.h"
// Internal Private Implementation
#include "api/engine.hpp"
#include "engine_trt/config.hpp"
#include "engine_trt/kv_event.hpp"
#include "engine_trt/request.hpp"
#include "engine_trt/response.hpp"
#include "engine_trt/stats.hpp"
// TensorRT LLM Executor
#include "NvInfer.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
// Third-party
#include <spdlog/sinks/stdout_color_sinks.h>
#include <spdlog/spdlog.h>
namespace ex = tensorrt_llm::executor;
namespace nvidia::nvllm::trt {
/// Customize the logger for TensorRT LLM using a module-specific spdlog logger
class TRTLogger : public nvinfer1::ILogger
{
public:
TRTLogger(std::shared_ptr<spdlog::logger> logger) : m_logger(logger) {}
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override
{
if (severity <= nvinfer1::ILogger::Severity::kERROR)
{
m_logger->error("{}", msg);
}
else if (severity == nvinfer1::ILogger::Severity::kWARNING)
{
m_logger->warn("{}", msg);
}
else
{
m_logger->info("{}", msg);
}
}
private:
std::shared_ptr<spdlog::logger> m_logger;
};
/// Real implementation: owns the TRT-LLM executor, an optional KV-cache event
/// manager, and the shutdown state. m_mutex guards m_has_completed only; the
/// executor's own calls are assumed thread-safe per TRT-LLM's API.
class StreamingEngine::Impl
{
public:
    Impl(const std::string& config_proto);
    Impl(void* engine);
    ~Impl() = default;
    /// Enqueues a request to the executor
    /// In this opinionated implementation, [`client_id`] is required to be unique
    uint64_t enqueue_request(uint64_t client_id, const std::string& req_json)
    {
        spdlog::trace("enqueue_request - client_id: {}", client_id);
        auto request = deserialize_request(req_json);
        request.setClientId(client_id);
        auto request_id = m_executor->enqueueRequest(request);
        spdlog::trace("request_id: {} with client_id {} was enqueued", request_id, client_id);
        return request_id;
    }
    /// Cancellation is by [`request_id`], not [`client_id`]
    void cancel_request(uint64_t request_id)
    {
        spdlog::trace("cancel_request: {}", request_id);
        m_executor->cancelRequest(request_id);
    }
    /// Issues a shutdown request to the executor. This is a blocking call.
    /// We protect it with a mutex to ensure that it is only called once.
    void shutdown()
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        if (m_has_completed)
        {
            return;  // already shut down; idempotent
        }
        m_executor->shutdown();
        m_has_completed = true;
    }
    /// Returns true if the executor is ready to accept requests.
    /// Not sure of TensorRT LLM's behavior when the executor is shutdown, so we
    /// return false if the executor has completed.
    bool is_ready() const
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        if (m_has_completed)
        {
            return false;
        }
        return m_executor->canEnqueueRequests();
    }
    /// Returns true if the executor has completed.
    bool has_completed() const
    {
        std::lock_guard<std::mutex> lock(m_mutex);
        return m_has_completed;
    }
    /// Awaits on the executor for responses. This is a blocking call.
    /// TensorRT LLM will throw an exception if a thread is blocked on the calls and the
    /// executor is shutdown.
    std::string await_responses()
    {
        spdlog::trace("blocking on await_responses");
        std::deque<ex::Response> responses;
        bool shutdown = false;
        try
        {
            auto v_responses = m_executor->awaitResponses();
            spdlog::trace("received {} responses", v_responses.size());
            for (auto& response : v_responses)
            {
                responses.push_back(std::move(response));
            }
        } catch (const std::exception& e)
        {
            // NOTE(review): any exception is interpreted as shutdown; the
            // exception message itself is not logged — confirm this is wanted.
            spdlog::trace("Exception caught awaiting responses; shutting down");
            shutdown = true;
        }
        return serialize_responses(std::move(responses), shutdown);
    }
    /// Awaits for KV events. This is a blocking call with a timeout of 250ms.
    /// The current implementation will not throw an exception if the executor is shutdown,
    /// so we need timeout the call to ensure that calling thread can shutdown properly.
    std::optional<std::string> await_kv_events()
    {
        // Lazily fetch the event manager on first use; it may not exist
        // (e.g. events disabled), in which case nullopt is returned.
        if (m_kv_cache_event_manager == nullptr)
        {
            auto manager = m_executor->getKVCacheEventManager();
            if (manager)
            {
                m_kv_cache_event_manager = *manager;
            }
        }
        if (m_kv_cache_event_manager == nullptr)
        {
            return std::nullopt;
        }
        try
        {
            auto events = m_kv_cache_event_manager->getLatestEvents({std::chrono::milliseconds(250)});
            if (!events.empty())
            {
                spdlog::trace("received {} on kv_events", events.size());
            }
            return {serialize_kv_events(std::move(events), false)};
        } catch (const std::exception& e)
        {
            spdlog::trace("Exception caught awaiting kv events; shutting down");
            return {serialize_kv_events({}, true)};
        }
    }
    // Awaits iteration stats
    std::optional<std::string> await_iter_stats()
    {
        auto iter_stats = m_executor->getLatestIterationStats();
        return serialize_iter_stats(iter_stats);
    }
private:
    std::unique_ptr<ex::Executor> m_executor;
    std::shared_ptr<ex::KVCacheEventManager> m_kv_cache_event_manager = nullptr;
    bool m_has_completed = false;  // set once by shutdown(); guarded by m_mutex
    mutable std::mutex m_mutex;
};
// Private Engine Impl
// Adopts a raw, already-constructed executor pointer (Triton interop path).
// NOTE(review): reset() takes ownership and will delete the executor on
// destruction — confirm the caller does not also own/free this pointer, and
// that the pointee's TRT-LLM version exactly matches this build.
StreamingEngine::Impl::Impl(void* engine)
{
    auto nvllm_logger = spdlog::stdout_color_mt("nvllm");
    spdlog::set_default_logger(nvllm_logger);
    spdlog::info("Instantiating nvLLM from raw TensorRT LLM Executor pointer");
    m_executor.reset(reinterpret_cast<ex::Executor*>(engine));
}
/// Builds the executor from a JSON config string: sets up the nvllm/trtllm
/// spdlog loggers, applies the configured log level, installs the TRT plugin
/// logger, forces KV block reuse + event buffering, then constructs the
/// decoder-only executor.
StreamingEngine::Impl::Impl(const std::string& config_json)
{
    auto nvllm_logger = spdlog::stdout_color_mt("nvllm");
    auto trtllm_logger = spdlog::stdout_color_mt("trtllm");
    spdlog::set_default_logger(nvllm_logger);
    auto config = deserialize_config(config_json);
    // Map the configured textual level onto spdlog's enum; "error" and any
    // unrecognized value both fall back to err (same as the original chain).
    spdlog::level::level_enum level = spdlog::level::err;
    if (config.log_level == "warn")
    {
        level = spdlog::level::warn;
    }
    else if (config.log_level == "info")
    {
        level = spdlog::level::info;
    }
    else if (config.log_level == "debug")
    {
        level = spdlog::level::debug;
    }
    else if (config.log_level == "trace")
    {
        level = spdlog::level::trace;
    }
    spdlog::set_level(level);
    nvllm_logger->set_level(level);
    trtllm_logger->set_level(level);
    // NOTE(review): the plugin logger is allocated with new and never freed —
    // presumably the plugin registry needs it for the process lifetime; confirm.
    auto* trt_logger = new TRTLogger(trtllm_logger);
    initTrtLlmPlugins(trt_logger);
    auto kv_config = config.config.getKvCacheConfig();
    spdlog::info("Enabled block reuse: true");
    kv_config.setEnableBlockReuse(true);
    kv_config.setEventBufferMaxSize(65536);
    config.config.setKvCacheConfig(kv_config);
    m_executor = std::make_unique<ex::Executor>(config.model_path, ex::ModelType::kDECODER_ONLY, config.config);
}
// Public Engine Impl
// Thin pimpl forwarders: every public method delegates directly to Impl.
StreamingEngine::StreamingEngine(const std::string& config_proto) :
  m_impl{std::make_unique<Impl>(config_proto)} {}
StreamingEngine::StreamingEngine(void* engine) :
  m_impl{std::make_unique<Impl>(engine)} {}
// Ensure the executor is shut down before the Impl is destroyed.
StreamingEngine::~StreamingEngine()
{
    if (!m_impl->has_completed())
    {
        m_impl->shutdown();
    }
}
uint64_t StreamingEngine::enqueue_request(uint64_t client_id, const std::string& req_proto)
{
    return m_impl->enqueue_request(client_id, req_proto);
}
std::string StreamingEngine::await_responses()
{
    return m_impl->await_responses();
}
std::optional<std::string> StreamingEngine::await_kv_events()
{
    return m_impl->await_kv_events();
}
std::optional<std::string> StreamingEngine::await_iter_stats()
{
    return m_impl->await_iter_stats();
}
void StreamingEngine::cancel_request(uint64_t request_id)
{
    m_impl->cancel_request(request_id);
}
void StreamingEngine::shutdown()
{
    m_impl->shutdown();
}
bool StreamingEngine::is_ready() const
{
    return m_impl->is_ready();
}
bool StreamingEngine::has_completed() const
{
    return m_impl->has_completed();
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "engine_trt/kv_event.hpp"
#include <nlohmann/json.hpp>
#include <spdlog/spdlog.h>
#include <xxhash.h>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::json;
namespace ex = tensorrt_llm::executor;
// nlohmann::json requires to_json/from_json to live in the same namespace as
// the serialized type (ADL), hence these helpers sit in tensorrt_llm::executor.
namespace tensorrt_llm::executor {
// Serialization for KVCacheRemovedData
void to_json(json& j, const KVCacheRemovedData& data)
{
    j = json{{"block_hashes", data.blockHashes}};
}
void from_json(const json& j, KVCacheRemovedData& data)
{
    j.at("block_hashes").get_to(data.blockHashes);
}
}  // namespace tensorrt_llm::executor
namespace nvidia::nvllm::trt {
using IdType = ex::IdType;
using TokenIdType = ex::TokenIdType;
/// JSON-friendly projection of a stored KV block: the executor's block hash,
/// a locally computed hash of the block's token ids, and the LoRA id.
/// Fixes: members are now value-initialized (the defaulted constructor
/// previously left them indeterminate), and the token buffer reserves its
/// final size before the copy loop.
struct KVCacheStoredBlockData
{
    KVCacheStoredBlockData() = default;
    KVCacheStoredBlockData(const ex::KVCacheStoredBlockData& data)
    {
        std::vector<TokenIdType> tokens;
        tokens.reserve(data.tokens.size());
        for (const auto& token : data.tokens)
        {
            tokens.push_back(token.tokenId);
        }
        auto size = tokens.size() * sizeof(TokenIdType);
        // Fixed seed keeps the token hash stable across processes.
        auto hash = XXH3_64bits_withSeed(tokens.data(), size, 1337);
        this->block_hash = data.blockHash;
        this->tokens_hash = hash;
        this->lora_id = data.loraId;
    }
    /// @brief The hash of the block
    IdType block_hash{};
    /// @brief The hash of the tokens in the block
    IdType tokens_hash{};
    /// @brief The Lora ID of the block
    IdType lora_id{};
};
// Serialization for KVCacheStoredBlockData
void to_json(json& j, const KVCacheStoredBlockData& data)
{
j = json{
{"block_hash", data.block_hash},
{"tokens_hash", data.tokens_hash},
{"lora_id", data.lora_id},
};
}
void from_json(const json& j, KVCacheStoredBlockData& data)
{
j.at("block_hash").get_to(data.block_hash);
j.at("tokens_hash").get_to(data.tokens_hash);
j.at("lora_id").get_to(data.lora_id);
}
/// JSON-friendly projection of a "stored" KV event: an optional parent hash
/// plus the chain of stored blocks.
/// Fix: reserve the blocks vector before the per-block emplace loop to avoid
/// repeated reallocation.
struct KVCacheStoredData
{
    KVCacheStoredData() = default;
    KVCacheStoredData(ex::KVCacheStoredData&& data) : parent_hash(std::move(data.parentHash))
    {
        blocks.reserve(data.blocks.size());
        for (auto& block : data.blocks)
        {
            blocks.emplace_back(block);
        }
    }
    /// @brief The parent of this sequence of stored blocks
    std::optional<IdType> parent_hash;
    /// @brief A sequence of blocks. The parent of block `i` is block `i-1`
    std::vector<KVCacheStoredBlockData> blocks;
};
using KVCacheRemovedData = ex::KVCacheRemovedData;
// Serialization for KVCacheStoredData
void to_json(json& j, const KVCacheStoredData& data)
{
j = json{{"blocks", data.blocks}};
if (data.parent_hash)
{
j["parent_hash"] = data.parent_hash.value();
}
}
void from_json(const json& j, KVCacheStoredData& data)
{
j.at("blocks").get_to(data.blocks);
if (j.contains("parent_hash"))
{
data.parent_hash = j.at("parent_hash").get<IdType>();
}
}
struct KVCacheEventData
{
KVCacheEventData() = default;
explicit KVCacheEventData(ex::KVCacheEventData&& data)
{
if (std::holds_alternative<ex::KVCacheStoredData>(data))
{
stored = KVCacheStoredData(std::move(std::get<ex::KVCacheStoredData>(data)));
}
else if (std::holds_alternative<ex::KVCacheRemovedData>(data))
{
removed = std::move(std::get<ex::KVCacheRemovedData>(data));
}
}
std::optional<KVCacheStoredData> stored;
std::optional<KVCacheRemovedData> removed;
};
// Serialization for KVCacheEventData
void to_json(json& j, const KVCacheEventData& data)
{
if (data.stored)
{
j["stored"] = data.stored.value();
}
else if (data.removed)
{
j["removed"] = data.removed.value();
}
}
void from_json(const json& j, KVCacheEventData& data)
{
if (j.contains("stored"))
{
data.stored = {j.at("stored").get<KVCacheStoredData>()};
}
else if (j.contains("removed"))
{
data.removed = {j.at("removed").get<KVCacheRemovedData>()};
}
}
/// Single serializable KV-cache event: the executor-assigned event id plus
/// the projected payload.
/// NOTE(review): the (IdType, KVCacheEventData) constructor is declared but
/// no definition is visible here; the lack of a default constructor also
/// means from_json below cannot be used via json::get<KVCacheEvent>() —
/// confirm both are intentional.
struct KVCacheEvent
{
    KVCacheEvent(IdType eventId, KVCacheEventData data);
    KVCacheEvent(ex::KVCacheEvent&& event) : event_id(std::move(event.eventId)), data(std::move(event.data)) {}
    /// @brief The unique id of this event
    IdType event_id;
    /// @brief The data corresponding to this event
    KVCacheEventData data;
};
inline void to_json(json& j, const KVCacheEvent& event)
{
    j = json{{"event_id", event.event_id}, {"data", event.data}};
}
inline void from_json(const json& j, KVCacheEvent& event)
{
    j.at("event_id").get_to(event.event_id);
    j.at("data").get_to(event.data);
}
/// Batch of events plus a shutdown flag signalling the consumer to stop.
struct KVCacheEvents
{
    std::vector<KVCacheEvent> events;
    bool shutdown;
};
inline void to_json(json& j, const KVCacheEvents& events)
{
    j = json{{"events", events.events}, {"shutdown", events.shutdown}};
}
// inline void from_json(const json& j, KVCacheEvents& events)
// {
//     j.at("events").get_to(events.events);
//     j.at("shutdown").get_to(events.shutdown);
// }
/// Converts a batch of executor KV-cache events to a JSON string, dropping
/// Created/Updated variants (only Stored and Removed are forwarded). The
/// shutdown flag is passed through to the serialized payload.
/// Fix: events are now moved out of the deque instead of copied
/// (`events_in.front()` previously copied each event, including its token
/// vectors) and the output vector reserves the batch size up front.
std::string serialize_kv_events(std::deque<tensorrt_llm::executor::KVCacheEvent> events_in, bool shutdown)
{
    std::vector<KVCacheEvent> events_out;
    events_out.reserve(events_in.size());
    while (!events_in.empty())
    {
        auto event = std::move(events_in.front());
        events_in.pop_front();
        if (std::holds_alternative<ex::KVCacheCreatedData>(event.data) ||
            std::holds_alternative<ex::KVCacheUpdatedData>(event.data))
        {
            continue;
        }
        events_out.emplace_back(std::move(event));
    }
    KVCacheEvents events{std::move(events_out), shutdown};
    return json(events).dump();
}
} // namespace nvidia::nvllm::trt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment