OpenDAS / dynamo · Commits

Commit 4f6f63cd — feat: add rust based tokenizer
Authored Feb 24, 2025 by Biswa Panda; committed by GitHub on Feb 24, 2025
Parent: 53163693

Showing 20 changed files with 3,423 additions and 140 deletions (+3423 −140)
Files changed:

  .pre-commit-config.yaml                                              +1    −1
  applications/llm/count/Cargo.lock                                  +690   −44
  applications/llm/tio/Cargo.lock                                    +134   −34
  container/Dockerfile                                                 +4    −1
  container/Dockerfile.vllm                                            +4    −1
  examples/rust/Cargo.lock                                           +683   −43
  llm/rust/Cargo.lock                                                +221    −9
  llm/rust/triton-llm/Cargo.toml                                      +32    −1
  llm/rust/triton-llm/src/backend.rs                                 +529    −0
  llm/rust/triton-llm/src/lib.rs                                       +3    −0
  llm/rust/triton-llm/src/model_card/create.rs                        +12    −6
  llm/rust/triton-llm/src/preprocessor.rs                            +356    −0
  llm/rust/triton-llm/src/preprocessor/prompt.rs                      +61    −0
  llm/rust/triton-llm/src/preprocessor/prompt/template.rs             +94    −0
  llm/rust/triton-llm/src/preprocessor/prompt/template/context.rs     +56    −0
  llm/rust/triton-llm/src/preprocessor/prompt/template/formatters.rs +107    −0
  llm/rust/triton-llm/src/preprocessor/prompt/template/oai.rs        +111    −0
  llm/rust/triton-llm/src/preprocessor/prompt/template/tokcfg.rs     +159    −0
  llm/rust/triton-llm/src/preprocessor/tools.rs                      +115    −0
  llm/rust/triton-llm/src/preprocessor/tools/request.rs               +51    −0
.pre-commit-config.yaml:

@@ -43,7 +43,7 @@ repos:
       - id: codespell
         additional_dependencies: [tomli]
         args: ["--toml", "pyproject.toml"]
-        exclude: (?x)^(.*stemmer.*|.*stop_words.*|^CHANGELOG.md$|.*tests/data/replays.*)
+        exclude: (?x)^(.*stemmer.*|.*stop_words.*|^CHANGELOG.md$|.*tests/data/*)
   # More details about these pre-commit hooks here:
   # https://pre-commit.com/hooks.html
   - repo: https://github.com/pre-commit/pre-commit-hooks
applications/llm/count/Cargo.lock: diff collapsed (not shown).
applications/llm/tio/Cargo.lock:

@@ -129,9 +129,9 @@ dependencies = [

 [[package]]
 name = "anyhow"
-version = "1.0.95"
+version = "1.0.96"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04"
+checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4"

 [[package]]
 name = "arbitrary"
@@ -478,6 +478,17 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"

+[[package]]
+name = "bs62"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "afdf0f7a8a430954ae6cba49be4af8fa3a4bee15d9210aa8a82aee6abca531bd"
+dependencies = [
+ "lazy_static",
+ "num-bigint",
+ "num-traits",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.17.0"
@@ -627,9 +638,9 @@ dependencies = [

 [[package]]
 name = "cc"
-version = "1.2.14"
+version = "1.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9"
+checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af"
 dependencies = [
  "jobserver",
  "libc",
@@ -946,9 +957,9 @@ dependencies = [

 [[package]]
 name = "cudarc"
-version = "0.13.6"
+version = "0.13.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cf16a4eaf3c5c36c9a7e4096bf8611cd963aa71d6b67162d538d7ea13befeeea"
+checksum = "4e29ce3bfa797c1183053ceb496316203ef561c183941c3c181500d9ade6daf4"
 dependencies = [
  "half",
  "libloading",
@@ -1354,6 +1365,16 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"

+[[package]]
+name = "erased-serde"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
+dependencies = [
+ "serde",
+ "typeid",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.10"
@@ -2482,7 +2503,7 @@
  "rustc-hash",
  "serde",
  "serde_json",
- "toktrie",
+ "toktrie 0.1.0",
  "url",
 ]
@@ -2510,9 +2531,9 @@ dependencies = [

 [[package]]
 name = "log"
-version = "0.4.25"
+version = "0.4.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
+checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"

 [[package]]
 name = "lrtable"
@@ -2589,6 +2610,12 @@ dependencies = [
  "stable_deref_trait",
 ]

+[[package]]
+name = "memo-map"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
+
 [[package]]
 name = "metal"
 version = "0.27.0"
@@ -2616,6 +2643,8 @@ version = "2.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cff7b8df5e85e30b87c2b0b3f58ba3a87b68e133738bf512a7713769326dbca9"
 dependencies = [
+ "memo-map",
+ "self_cell",
  "serde",
  "serde_json",
 ]
@@ -2638,9 +2667,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"

 [[package]]
 name = "miniz_oxide"
-version = "0.8.4"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b"
+checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
 dependencies = [
  "adler2",
  "simd-adler32",
@@ -2783,12 +2812,12 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-rayon",
- "toktrie_hf_tokenizers",
+ "toktrie_hf_tokenizers 0.1.0",
  "toml",
  "tqdm",
  "tracing",
  "tracing-subscriber",
- "uuid 1.13.2",
+ "uuid 1.14.0",
  "variantly",
  "vob",
 ]
@@ -2872,9 +2901,9 @@ checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"

 [[package]]
 name = "native-tls"
-version = "0.2.13"
+version = "0.2.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c"
+checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
 dependencies = [
  "libc",
  "log",
@@ -3083,6 +3112,16 @@ dependencies = [
  "rand",
 ]

+[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-complex"
 version = "0.4.6"
@@ -3099,6 +3138,15 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"

+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.19"
@@ -3715,9 +3763,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"

 [[package]]
 name = "redox_syscall"
-version = "0.5.8"
+version = "0.5.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
+checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
 dependencies = [
  "bitflags 2.8.0",
 ]
@@ -3857,9 +3905,9 @@ dependencies = [

 [[package]]
 name = "ring"
-version = "0.17.9"
+version = "0.17.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e75ec5e92c4d8aede845126adc388046234541629e76029599ed35a003c7ed24"
+checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73"
 dependencies = [
  "cc",
  "cfg-if 1.0.0",
@@ -4093,11 +4141,20 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "self_cell"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c2fdfc24bc566f839a2da4c4295b82db7d25a24253867d5c64355abb5799bdbe"
+
 [[package]]
 name = "semver"
 version = "1.0.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03"
+dependencies = [
+ "serde",
+]

 [[package]]
 name = "seq-macro"
@@ -4107,18 +4164,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"

 [[package]]
 name = "serde"
-version = "1.0.217"
+version = "1.0.218"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
+checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
 dependencies = [
  "serde_derive",
 ]

 [[package]]
 name = "serde_derive"
-version = "1.0.217"
+version = "1.0.218"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
+checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -4138,9 +4195,9 @@ dependencies = [

 [[package]]
 name = "serde_json"
-version = "1.0.138"
+version = "1.0.139"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
+checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
 dependencies = [
  "indexmap 2.7.1",
  "itoa",
@@ -4841,6 +4898,19 @@ dependencies = [
  "serde_json",
 ]

+[[package]]
+name = "toktrie"
+version = "0.6.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f9c32a81c3faff7dde7909b471a5c39970e684b7fb88994a0fe4d9a3fb3a2b1"
+dependencies = [
+ "anyhow",
+ "bytemuck",
+ "bytemuck_derive",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "toktrie_hf_tokenizers"
 version = "0.1.0"
@@ -4852,7 +4922,21 @@ dependencies = [
  "serde",
  "serde_json",
  "tokenizers",
- "toktrie",
+ "toktrie 0.1.0",
 ]

+[[package]]
+name = "toktrie_hf_tokenizers"
+version = "0.6.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e43a051aef243334de35fe3ba70060e11aa373eeb193cc98cce37383d312002"
+dependencies = [
+ "anyhow",
+ "log",
+ "serde",
+ "serde_json",
+ "tokenizers",
+ "toktrie 0.6.28",
+]
+
 [[package]]
@@ -5107,7 +5191,7 @@ dependencies = [
  "tokio-util",
  "tracing",
  "tracing-subscriber",
- "uuid 1.13.2",
+ "uuid 1.14.0",
  "validator",
  "xxhash-rust",
 ]
@@ -5121,25 +5205,35 @@ dependencies = [
  "async-trait",
  "axum 0.8.1",
  "blake3",
+ "bs62",
  "bytes",
  "chrono",
  "derive_builder",
  "either",
+ "erased-serde",
  "futures",
  "galil-seiferas",
  "indexmap 2.7.1",
+ "itertools 0.14.0",
+ "minijinja",
+ "minijinja-contrib",
  "mistralrs",
  "prometheus",
  "regex",
+ "semver",
  "serde",
  "serde_json",
  "thiserror 2.0.11",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tokio-util",
+ "toktrie 0.6.28",
+ "toktrie_hf_tokenizers 0.6.28",
  "tracing",
  "triton-distributed",
+ "unicode-segmentation",
- "uuid 1.13.2",
+ "uuid 1.14.0",
  "validator",
  "xxhash-rust",
 ]
@@ -5161,6 +5255,12 @@ dependencies = [
  "tokio",
 ]

+[[package]]
+name = "typeid"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
+
 [[package]]
 name = "typenum"
 version = "1.18.0"
@@ -5184,9 +5284,9 @@ checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"

 [[package]]
 name = "unicode-ident"
-version = "1.0.16"
+version = "1.0.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
+checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe"

 [[package]]
 name = "unicode-normalization-alignments"
@@ -5287,9 +5387,9 @@ dependencies = [

 [[package]]
 name = "uuid"
-version = "1.13.2"
+version = "1.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6"
+checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1"
 dependencies = [
  "getrandom 0.3.1",
  "serde",
@@ -5756,9 +5856,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

 [[package]]
 name = "winnow"
-version = "0.7.2"
+version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
+checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
 dependencies = [
  "memchr",
 ]
container/Dockerfile:

@@ -25,7 +25,10 @@ USER root

 # TODO: separate dev from runtime dependendcies
 # Rust build/dev dependencies
-RUN apt-get update; apt-get install -y gdb protobuf-compiler
+RUN apt-get update && \
+    apt-get install --no-install-recommends --yes \
+        gdb protobuf-compiler cmake libssl-dev pkg-config

 RUN curl https://sh.rustup.rs -sSf | bash -s -- -y

 ENV PATH="/root/.cargo/bin:${PATH}"
container/Dockerfile.vllm:

@@ -62,7 +62,10 @@ RUN ln -sf /bin/bash /bin/sh
 RUN apt update -y && \
     apt install -y \
         build-essential \
-        protobuf-compiler && \
+        protobuf-compiler \
+        cmake \
+        libssl-dev \
+        pkg-config && \
     curl https://sh.rustup.rs -sSf | bash -s -- -y

 ENV PATH="/root/.cargo/bin:${PATH}"
examples/rust/Cargo.lock: diff collapsed (not shown).
llm/rust/Cargo.lock:

@@ -492,6 +492,17 @@ version = "0.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32"

+[[package]]
+name = "bs62"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "afdf0f7a8a430954ae6cba49be4af8fa3a4bee15d9210aa8a82aee6abca531bd"
+dependencies = [
+ "lazy_static",
+ "num-bigint",
+ "num-traits",
+]
+
 [[package]]
 name = "bstr"
 version = "1.11.3"
@@ -758,6 +769,15 @@ version = "0.7.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"

+[[package]]
+name = "cmake"
+version = "0.1.54"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "color_quant"
 version = "1.1.0"
@@ -1366,6 +1386,16 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"

+[[package]]
+name = "erased-serde"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
+dependencies = [
+ "serde",
+ "typeid",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.10"
@@ -1392,7 +1422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc0452bcc559431b16f472b7ab86e2f9ccd5f3c2da3795afbd6b773665e047fe"
 dependencies = [
  "http",
- "prost",
+ "prost 0.13.4",
  "tokio",
  "tokio-stream",
  "tonic",
@@ -1950,6 +1980,28 @@ dependencies = [
  "ureq",
 ]

+[[package]]
+name = "hf-hub"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "112fa2f6ad4ab815b9e1b938b4b1e437032d055e2f92ed10fd6ab2e62d02c6b6"
+dependencies = [
+ "dirs",
+ "futures",
+ "http",
+ "indicatif",
+ "log",
+ "native-tls",
+ "num_cpus",
+ "rand",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.11",
+ "tokio",
+ "ureq",
+]
+
 [[package]]
 name = "http"
 version = "1.2.0"
@@ -2341,6 +2393,7 @@ dependencies = [
  "pest",
  "pest_derive",
  "pin-project",
+ "regex",
  "serde",
  "similar",
  "walkdir",
@@ -2376,6 +2429,15 @@ version = "1.70.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"

+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itertools"
 version = "0.11.0"
@@ -2403,6 +2465,15 @@ dependencies = [
  "either",
 ]

+[[package]]
+name = "itertools"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "1.0.14"
@@ -2552,7 +2623,7 @@ dependencies = [
  "rustc-hash",
  "serde",
  "serde_json",
- "toktrie",
+ "toktrie 0.1.0",
  "url",
 ]
@@ -2659,6 +2730,12 @@ dependencies = [
  "stable_deref_trait",
 ]

+[[package]]
+name = "memo-map"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
+
 [[package]]
 name = "metal"
 version = "0.27.0"
@@ -2686,6 +2763,8 @@ version = "2.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cff7b8df5e85e30b87c2b0b3f58ba3a87b68e133738bf512a7713769326dbca9"
 dependencies = [
+ "memo-map",
+ "self_cell",
  "serde",
  "serde_json",
 ]
@@ -2853,7 +2932,7 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tokio-rayon",
- "toktrie_hf_tokenizers",
+ "toktrie_hf_tokenizers 0.1.0",
  "toml",
  "tqdm",
  "tracing",
@@ -3078,6 +3157,16 @@ dependencies = [
  "rand",
 ]

+[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-complex"
 version = "0.4.6"
@@ -3094,6 +3183,26 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"

+[[package]]
+name = "num-derive"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.19"
@@ -3574,6 +3683,16 @@ dependencies = [
  "unarray",
 ]

+[[package]]
+name = "prost"
+version = "0.11.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
+dependencies = [
+ "bytes",
+ "prost-derive 0.11.9",
+]
+
 [[package]]
 name = "prost"
 version = "0.13.4"
@@ -3581,7 +3700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec"
 dependencies = [
  "bytes",
- "prost-derive",
+ "prost-derive 0.13.4",
 ]
@@ -3597,13 +3716,26 @@ dependencies = [
  "once_cell",
  "petgraph",
  "prettyplease",
- "prost",
+ "prost 0.13.4",
  "prost-types",
  "regex",
  "syn 2.0.98",
  "tempfile",
 ]

+[[package]]
+name = "prost-derive"
+version = "0.11.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
+dependencies = [
+ "anyhow",
+ "itertools 0.10.5",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "prost-derive"
 version = "0.13.4"
@@ -3623,7 +3755,7 @@ version = "0.13.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc"
 dependencies = [
- "prost",
+ "prost 0.13.4",
 ]
@@ -4260,11 +4392,46 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "self_cell"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c2fdfc24bc566f839a2da4c4295b82db7d25a24253867d5c64355abb5799bdbe"
+
 [[package]]
 name = "semver"
 version = "1.0.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "sentencepiece"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ae716e54c860d65df824a5b606b464e8f2acfc4a7fe93b2a1f6b9a173d1fff5"
+dependencies = [
+ "libc",
+ "num-derive",
+ "num-traits",
+ "prost 0.11.9",
+ "prost-derive 0.11.9",
+ "sentencepiece-sys",
+ "thiserror 1.0.69",
+]
+
+[[package]]
+name = "sentencepiece-sys"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f21c66315e346665798e116d1c21201434715e13dd691f3f33f6276746d0b71f"
+dependencies = [
+ "cc",
+ "cmake",
+ "pkg-config",
+]
+
 [[package]]
 name = "seq-macro"
@@ -4865,7 +5032,7 @@ dependencies = [
  "derive_builder",
  "esaxx-rs",
  "getrandom 0.2.15",
- "hf-hub",
+ "hf-hub 0.3.2",
  "indicatif",
  "itertools 0.12.1",
  "lazy_static",
@@ -5005,6 +5172,19 @@ dependencies = [
  "serde_json",
 ]

+[[package]]
+name = "toktrie"
+version = "0.6.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f9c32a81c3faff7dde7909b471a5c39970e684b7fb88994a0fe4d9a3fb3a2b1"
+dependencies = [
+ "anyhow",
+ "bytemuck",
+ "bytemuck_derive",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "toktrie_hf_tokenizers"
 version = "0.1.0"
@@ -5016,7 +5196,21 @@ dependencies = [
  "serde",
  "serde_json",
  "tokenizers",
- "toktrie",
+ "toktrie 0.1.0",
 ]

+[[package]]
+name = "toktrie_hf_tokenizers"
+version = "0.6.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e43a051aef243334de35fe3ba70060e11aa373eeb193cc98cce37383d312002"
+dependencies = [
+ "anyhow",
+ "log",
+ "serde",
+ "serde_json",
+ "tokenizers",
+ "toktrie 0.6.28",
+]
+
 [[package]]
@@ -5073,7 +5267,7 @@ dependencies = [
  "hyper-util",
  "percent-encoding",
  "pin-project",
- "prost",
+ "prost 0.13.4",
  "socket2",
  "tokio",
  "tokio-stream",
@@ -5285,26 +5479,38 @@ dependencies = [
  "async-trait",
  "axum 0.8.1",
  "blake3",
+ "bs62",
  "bytes",
  "chrono",
  "derive_builder",
  "either",
+ "erased-serde",
  "futures",
  "galil-seiferas",
+ "hf-hub 0.4.1",
  "indexmap 2.7.1",
  "insta",
+ "itertools 0.14.0",
+ "minijinja",
+ "minijinja-contrib",
  "mistralrs",
  "prometheus",
  "proptest",
  "regex",
  "reqwest",
  "rstest",
+ "semver",
+ "sentencepiece",
  "serde",
  "serde_json",
  "tempfile",
  "thiserror 2.0.11",
+ "tokenizers",
  "tokio",
  "tokio-stream",
  "tokio-util",
+ "toktrie 0.6.28",
+ "toktrie_hf_tokenizers 0.6.28",
  "tracing",
  "triton-distributed",
+ "unicode-segmentation",
@@ -5330,6 +5536,12 @@ dependencies = [
  "tokio",
 ]

+[[package]]
+name = "typeid"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
+
 [[package]]
 name = "typenum"
 version = "1.17.0"
llm/rust/triton-llm/Cargo.toml:

@@ -25,6 +25,7 @@ homepage.workspace = true
 mistralrs = ["dep:mistralrs"]
 metal = ["mistralrs/metal"]
 cuda = ["mistralrs/cuda"]
+sentencepiece = ["dep:sentencepiece"]

 [dependencies]
@@ -71,9 +72,39 @@ either = { version = "1.13" }
 indexmap = { version = "2.6" }
 mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }

+# tokenizers
+tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "esaxx_fast"] }
+sentencepiece = { version = "0.11.2", optional = true }
+
+# backend
 galil-seiferas = { version = "0.1" }
+toktrie = { version = "0.6.28" }
+toktrie_hf_tokenizers = { version = "0.6.28" }
+
+# preprocessor
+bs62 = { version = "0.1" }
+erased-serde = { version = "0.4" }
+itertools = { version = "0.14.0" }
+minijinja = { version = "2.3.1", features = ["loader"] }
+minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
+semver = { version = "1", features = ["serde"] }

 [dev-dependencies]
-insta = { version = "1.41", features = ["glob", "json", "redactions"]}
 proptest = "1.5.0"
 reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
 rstest = "0.18.2"
 tempfile = "3.17.1"
+hf-hub = "0.4.1"
+insta = { version = "1.41", features = [
+  "glob",
+  "json",
+  "redactions",
+  "filters",
+] }

 [profile.dev.package]
 insta.opt-level = 3
\ No newline at end of file
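The new `sentencepiece = ["dep:sentencepiece"]` entry follows Cargo's optional-dependency feature pattern: the sentencepiece crate is only compiled and linked when a consumer opts in. A minimal sketch of how such a feature gate is typically used (the function name and file path here are illustrative, not taken from the diff):

// Illustrative only: the cfg-gating pattern that the optional `sentencepiece`
// feature enables; the real module layout in triton-llm may differ.
#[cfg(feature = "sentencepiece")]
pub fn load_spm(model_path: &str) -> anyhow::Result<sentencepiece::SentencePieceProcessor> {
    // Compiled only when building with e.g.:
    //   cargo build -p triton-llm --features sentencepiece
    Ok(sentencepiece::SentencePieceProcessor::open(model_path)?)
}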
llm/rust/triton-llm/src/backend.rs (new file, 0 → 100644):

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Backend
//!
//! An [`Backend`] is the final stage of the pipeline. It represents the execution of the LLM
//! on some processing hardware.
//!
//! At minimum, the Backend is split into two components, the [`Backend`] itself and a downstream [`ExecutionContext`].
//!
//! The [`ExecutionContext`] can be thought of as the core driver of the forward pass, whereas the [`Backend`] is the
//! manager of all resources and concurrent tasks surrounding the LLM execution context / forward pass.
//!
//! For almost every known scenario, detokenization and initial post processing must happen in the Backend.
//! Further post-processing can happen in the response stream. One example is the jailing mechanism for partial
//! hidden stop condition matches, which can be handled in the response stream rather than the backend.

use std::{collections::HashSet, sync::Arc};

use anyhow::{Error, Result};
use futures::stream::{self, StreamExt};
use tracing as log;

use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
use triton_distributed::{
    pipeline::{
        async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
        ServerStreamingEngine, SingleIn,
    },
    protocols::annotated::Annotated,
};

use crate::protocols::{
    common::{
        llm_backend::{BackendInput, BackendOutput, FinishReason, LLMEngineOutput},
        StopConditions,
    },
    TokenIdType,
};

use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use tokenizers::Tokenizer as HfTokenizer;
use toktrie::TokTrie;
use toktrie_hf_tokenizers::ByteTokenizer;

/// Represents the output stream from the execution engine
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Context for executing LLM inference, engine consumes backend input and produces execution output stream
pub type ExecutionContext = ServerStreamingEngine<BackendInput, ExecutionOutputStream>;

/// Backend handles resource management and orchestrates LLM execution
#[allow(dead_code)]
pub struct Backend {
    mdc: ModelDeploymentCard,
    pub tokenizer: Tokenizer,        // Handles token encoding/decoding
    tok_trie: Arc<TokTrie>,          // Efficient token lookup structure
    eos_token_ids: Vec<TokenIdType>, // End of sequence token IDs
    validate_engine_decode: bool,    // Enable validation of engine decoding
    mdcsum: String,                  // Model deployment checksum
}

/// Internal state for managing token decoding and stream processing
#[allow(dead_code)]
struct DecoderUnfoldState {
    stream: ManyOut<ExecutionOutputStream>,
    decoder: Decoder,
    validate_engine_decode: bool,
    mdcsum: String,
}

impl Backend {
    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let info = mdc.model_info.get_model_info().await?;

        let tokenizer = match &mdc.tokenizer {
            TokenizerKind::HfTokenizerJson(file) => {
                HfTokenizer::from_file(&file).map_err(Error::msg)?
            }
        };

        let bt = ByteTokenizer::from_tokenizer(tokenizer.clone())?;
        let toktrie = TokTrie::from(&bt.tokrx_info(), &bt.token_bytes());

        let mdcsum = mdc.mdcsum();
        let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
        let tokenizer = Tokenizer::from(Arc::new(tokenizer));

        Ok(Arc::new(Self {
            mdc,
            tokenizer,
            tok_trie: Arc::new(toktrie),
            eos_token_ids: info.eos_token_ids(),
            validate_engine_decode: false,
            mdcsum,
        }))
    }

    fn decoder(
        &self,
        stream: ManyOut<ExecutionOutputStream>,
        stop_conditions: StopConditions,
    ) -> DecoderUnfoldState {
        let decoder = Decoder::new(
            self.tokenizer.decode_stream(false),
            stop_conditions,
            self.mdcsum.clone(),
        );

        DecoderUnfoldState {
            stream,
            decoder,
            validate_engine_decode: self.validate_engine_decode,
            mdcsum: self.mdcsum.clone(),
        }
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for Backend
{
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
        next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
        // possible use the request
        let mut stop_conditions = request.stop_conditions.clone();

        // preprocessor should have set max_tokens
        // assert!(stop_conditions.max_tokens.is_some());
        if stop_conditions.max_tokens.is_none() {
            log::warn!("max_tokens is not set in stop_conditions; fixme");
            stop_conditions.max_tokens = Some(256);
        }

        let next_stream = next.generate(request).await?;
        let context = next_stream.context();

        let state = self.decoder(next_stream, stop_conditions);

        let processed_stream = stream::unfold(state, |mut state| async move {
            match state.stream.next().await {
                Some(output) => {
                    // move to state.process_output
                    // handle any error conditions / unwraps here

                    // events are pass thru
                    if output.is_event() || output.data.is_none() {
                        return Some((output, state));
                    }

                    // if we have a data field without an event, then we might need to update the data
                    if let Some(data) = &output.data {
                        if data.text.is_some() && !state.validate_engine_decode {
                            return Some((output, state));
                        }
                    }

                    let data = output.data.as_ref().unwrap();

                    let result = state.decoder.process_token_ids(&data.token_ids).unwrap();

                    // todo - propagate finish reason details - possibly an annotation
                    let finish_reason = match &result.stop_trigger {
                        Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
                        Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
                        Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
                            Some(FinishReason::Stop)
                        }
                        None => None,
                    };

                    if data.finish_reason.is_none() && finish_reason.is_some() {
                        tracing::debug!(
                            ?result.stop_trigger,
                            "upstream did not provide a finish reason; issuing a stop_generation request to free resources",
                        );
                        state.stream.context().stop_generating();
                    }

                    let text = result.text;
                    let tokens = result.tokens;

                    if state.validate_engine_decode {
                        if data.finish_reason != finish_reason {
                            log::warn!(
                                "finish reason mismatch: expected {:?}, got {:?}",
                                data.finish_reason,
                                finish_reason
                            );
                        }

                        if data.text.is_some() && data.text != text {
                            log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
                        }
                    }

                    // update output in-place
                    let mut output = output;
                    let mut data = output.data.take().unwrap();

                    data.finish_reason = finish_reason;
                    data.text = text;
                    data.tokens = Some(tokens);

                    output.data = Some(data);

                    Some((output, state))
                }
                None => None,
            }
        });

        // convert stream of processed Annotated<LLMEngineOutput> to Annotated<BackendOutput>
        let mdcsum = self.mdcsum.clone();
        let stream = processed_stream.map(move |output| {
            output.map_data(|data| {
                Ok(BackendOutput {
                    token_ids: data.token_ids,
                    tokens: data.tokens.unwrap_or_default(),
                    text: data.text,
                    cum_log_probs: data.cum_log_probs,
                    log_probs: data.log_probs,
                    finish_reason: data.finish_reason,
                    mdcsum: mdcsum.clone(),
                })
            })
        });

        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

// todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>,

/// The [`Decoder`] object could be a member of either the internal LLM engine or part of the
/// postprocessor. If in the postprocessor, should be minimally in the same process or at very minimum
/// on the same physical machine connected by an IPC.
#[allow(dead_code)]
pub struct Decoder {
    decode_stream: DecodeStream,

    // do not trigger stop conditions until at least this many tokens have been generated
    min_tokens: u32,

    // maximum number of tokens to generate - the llm engine should enforce this
    max_tokens: u32,

    // single tokens that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_ids: HashSet<TokenIdType>,

    // text sequences that if found in the response will trigger a stop condition after the
    // minimum number of tokens have been generated
    hidden_stop_sequences: Vec<String>,

    // number of generated tokens
    generated_tokens: u32,

    // content jailed by partial hidden stop matches
    jail: String,

    // maximum number of bytes for the largest stop sequence
    jail_max_bytes: usize,

    // the number of bytes currently jailed
    jailed_bytes: usize,

    // mdcsum
    mdcsum: String,
}

#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    MaxTokensLimit,
    HiddenStopTokenDetected(TokenIdType),
    HiddenStopSequenceDetected(String),
}

impl StopTrigger {
    pub fn should_hide_text(&self) -> bool {
        match self {
            StopTrigger::MaxTokensLimit => false,
            StopTrigger::HiddenStopTokenDetected(_) => true,
            StopTrigger::HiddenStopSequenceDetected(_) => true,
        }
    }
}

pub struct StepResult {
    pub token: Option<String>,
    pub stop_trigger: Option<StopTrigger>,
}

impl StepResult {
    fn ok(token: Option<String>) -> Self {
        Self {
            token,
            stop_trigger: None,
        }
    }

    fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
        Self {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }
}

/// Result of processing a sequence of tokens
pub struct SeqResult {
    pub tokens: Vec<Option<String>>,       // Individual decoded tokens
    pub text: Option<String>,              // Combined decoded text
    pub stop_trigger: Option<StopTrigger>, // Reason for stopping generation, if any
}

#[allow(dead_code)]
impl Decoder {
    pub fn new(
        decode_stream: DecodeStream,
        stop_condition: StopConditions,
        mdcsum: String,
    ) -> Self {
        let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
            .stop_token_ids_hidden
            .unwrap_or_default()
            .iter()
            .copied()
            .collect();

        let hidden_stop_sequences: Vec<String> = stop_condition
            .stop
            .unwrap_or_default()
            .iter()
            .map(|x| x.to_string())
            .collect();

        let jail_max_bytes = hidden_stop_sequences
            .iter()
            .map(|x| x.len())
            .max()
            .unwrap_or(0);

        Self {
            decode_stream,
            hidden_stop_ids,
            hidden_stop_sequences,
            //visible_stop_ids: HashSet::new(),
            //visible_stop_sequences: Vec::new(),
            min_tokens: stop_condition.min_tokens.unwrap_or(0),
            max_tokens: stop_condition.max_tokens.expect("max_tokens is required"),
            generated_tokens: 0,
            jail: String::new(),
            jail_max_bytes,
            jailed_bytes: 0,
            mdcsum,
        }
    }

    /// Minimum amount of work to determine if a given generated/decoded sequence should be stopped
    /// This method can be called by the inner most loop of the LLM engine or minimally in the same
    /// process as the LLM engine.
    ///
    /// In the future, this method may kick off async cpu/tokio tasks and or async cuda tasks to
    /// handle logits post-processing and/or other tasks.
    pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
        // increment the generated tokens
        self.generated_tokens += 1;

        // decode the token
        let token = self.decode_stream.step(token_id)?;

        // stop conditions to not apply until the minimum number of tokens have been generated
        if self.generated_tokens < self.min_tokens {
            return Ok(StepResult::ok(token));
        }

        // check for hidden stop tokens - eos takes precedence
        if self.hidden_stop_ids.contains(&token_id) {
            return Ok(StepResult::with_stop_trigger(
                token,
                StopTrigger::HiddenStopTokenDetected(token_id),
            ));
        }

        // next check max_tokens limit
        if self.generated_tokens >= self.max_tokens {
            return Ok(StepResult::with_stop_trigger(
                token,
                StopTrigger::MaxTokensLimit,
            ));
        }

        // check stop sequences - the jail will always hold at least the largest stop sequence
        // if jail_max_bytes is 0, then there are no stop sequences
        if self.jail_max_bytes > 0 {
            if let Some(token) = &token {
                let pre_append = self.jail.len();
                log::debug!("pre_append: {}", pre_append);
                log::debug!("jail: {}", self.jail);
                self.jail.push_str(token);
                log::debug!("post_append: {}", self.jail.len());
                log::debug!("jail: {}", self.jail);

                for seq in &self.hidden_stop_sequences {
                    log::debug!("stop seq: {}", seq);
                    if let Some(offset) =
                        galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
                    {
                        log::debug!("offset: {}", offset);
                        // return only new bytes after pre_append .. offset+seq.len()
                        // example: seq = "ox", token = "boxes", return "b"
                        // note: this changes when we start jailing tokens for partial matches
                        // on the suffix of teh jail with prefixes of the stop sequences
                        //
                        // we might have returned a partial match, if so, then offset < pre_append
                        // in that case, we return the empty string
                        let partial_token = if offset >= pre_append {
                            self.jail[pre_append..offset].to_string()
                        } else {
                            "".to_string()
                        };

                        return Ok(StepResult::with_stop_trigger(
                            Some(partial_token),
                            StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
                        ));
                    }
                }

                if self.jail.len() > self.jail_max_bytes {
                    // truncate the jail
                    let drain_len = self.jail.len() - self.jail_max_bytes;
                    self.jail.drain(0..drain_len);
                }
            }
        }

        Ok(StepResult::ok(token))
    }

    pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
        let mut text: Option<String> = None;
        let mut tokens = Vec::new();

        for token_id in token_ids {
            let StepResult {
                token,
                stop_trigger,
            } = self.step(*token_id)?;

            let hide_text = stop_trigger
                .as_ref()
                .map(|x| x.should_hide_text())
                .unwrap_or(false);

            if !hide_text {
                if let Some(token) = &token {
                    text.get_or_insert_with(String::new).push_str(token);
                }
            }

            tokens.push(token);

            if let Some(stop_trigger) = stop_trigger {
                return Ok(SeqResult {
                    tokens,
                    text,
                    stop_trigger: Some(stop_trigger),
                });
            }
        }

        Ok(SeqResult {
            tokens,
            text,
            stop_trigger: None,
        })
    }

    fn return_token(&self, token: Option<String>) -> StepResult {
        StepResult {
            token,
            stop_trigger: None,
        }
    }

    fn return_with_stop_trigger(
        &self,
        token: Option<String>,
        stop_trigger: StopTrigger,
    ) -> StepResult {
        StepResult {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }

    fn jailed_string(&self) -> Option<String> {
        if self.jailed_bytes > 0 {
            // get the last jailed_bytes from the jail
            Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
        } else {
            None
        }
    }
}
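The "jail" in `Decoder::step` buffers the tail of the decoded text so that a hidden stop sequence split across several tokens is still detected and never echoed to the client. A minimal, self-contained sketch of that sliding-window idea, using std's `str::find` in place of `galil_seiferas::gs_find` and a hypothetical `StopJail` type (a simplification, not the crate's API):

// Hypothetical, simplified illustration of the stop-sequence jail above.
struct StopJail {
    jail: String,
    max_bytes: usize, // length of the longest stop sequence
}

impl StopJail {
    fn new(stop_sequences: &[&str]) -> Self {
        let max_bytes = stop_sequences.iter().map(|s| s.len()).max().unwrap_or(0);
        Self { jail: String::new(), max_bytes }
    }

    /// Push one decoded token; returns (visible_text, matched_stop_sequence).
    fn push(&mut self, token: &str, stop_sequences: &[&str]) -> (String, Option<String>) {
        let pre_append = self.jail.len();
        self.jail.push_str(token);

        for seq in stop_sequences {
            if let Some(offset) = self.jail.find(*seq) {
                // Emit only the bytes that arrived before the match started;
                // if the match began in an earlier token, emit nothing.
                let visible = if offset >= pre_append {
                    self.jail[pre_append..offset].to_string()
                } else {
                    String::new()
                };
                return (visible, Some((*seq).to_string()));
            }
        }

        // Keep only a window long enough to catch a match that straddles tokens.
        if self.jail.len() > self.max_bytes {
            let drain = self.jail.len() - self.max_bytes;
            self.jail.drain(0..drain);
        }
        (token.to_string(), None)
    }
}

fn main() {
    let stops = ["</s>"];
    let mut jail = StopJail::new(&stops);
    for tok in ["Hello", " world", "</", "s>"] {
        let (visible, stop) = jail.push(tok, &stops);
        print!("{visible}");
        if let Some(s) = stop {
            println!("\n[stopped on {s:?}]");
            break;
        }
    }
}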
llm/rust/triton-llm/src/lib.rs:

@@ -18,10 +18,13 @@
 //! The `triton-llm` crate is a Rust library that provides a set of traits and types for building
 //! distributed LLM inference solutions.

+pub mod backend;
 pub mod common;
 pub mod engines;
 pub mod http;
 pub mod kv_router;
 pub mod model_card;
+pub mod preprocessor;
 pub mod protocols;
+pub mod tokenizers;
 pub mod types;
llm/rust/triton-llm/src/model_card/create.rs:

@@ -38,7 +38,10 @@ impl ModelDeploymentCard {
     /// - The path doesn't exist or isn't a directory
     /// - The path contains invalid Unicode characters
     /// - Required model files are missing or invalid
-    pub async fn from_local_path(local_root_dir: impl AsRef<Path>) -> anyhow::Result<Self> {
+    pub async fn from_local_path(
+        local_root_dir: impl AsRef<Path>,
+        model_name: Option<String>,
+    ) -> anyhow::Result<Self> {
         let local_root_dir = local_root_dir.as_ref();
         check_valid_local_repo_path(local_root_dir)?;
         let repo_id = local_root_dir
@@ -46,11 +49,14 @@ impl ModelDeploymentCard {
             .to_str()
             .ok_or_else(|| anyhow::anyhow!("Path contains invalid Unicode"))?
             .to_string();
-        let model_name = local_root_dir
+        let model_name = model_name.unwrap_or(
+            local_root_dir
             .file_name()
             .and_then(|n| n.to_str())
-            .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
-        Self::from_repo(&repo_id, model_name).await
+            .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?
+            .to_string(),
+        );
+        Self::from_repo(&repo_id, &model_name).await
     }

     /// TODO: This will be implemented after nova-hub is integrated with the model-card
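With the added `model_name` parameter, callers can override the name that would otherwise be derived from the directory. A brief usage sketch (the local paths and the external module path are assumptions, not taken from the diff):

use triton_llm::model_card::model::ModelDeploymentCard;

async fn load_cards() -> anyhow::Result<()> {
    // Fall back to the directory name ("Llama-3-8B") as the model name.
    let _default_name = ModelDeploymentCard::from_local_path("/models/Llama-3-8B", None).await?;

    // Or register the card under an explicit name.
    let _named =
        ModelDeploymentCard::from_local_path("/models/Llama-3-8B", Some("llama3".to_string()))
            .await?;
    Ok(())
}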
llm/rust/triton-llm/src/preprocessor.rs (new file, 0 → 100644):

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! The Preprocessor consists of the following modules
//!
//! - `translation`: This module converts the allowed Ingress message types to the corresponding
//!   internal representation.
//! - `apply`: This module applies ModelConfig defaults to any empty optional fields specified
//! - `prompt`: This module applies any prompt template logic to the internal Request object.
//! - `tokenize`: This module tokenizes the formatted prompt string and returns the token ids.
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.

pub mod prompt;
pub mod tools;

use anyhow::Result;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
use std::{collections::HashMap, sync::Arc};
use tracing;

use crate::model_card::model::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;

use triton_distributed::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed::pipeline::{
    async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
};
use triton_distributed::protocols::annotated::{Annotated, AnnotationsProvider};

use crate::protocols::{
    common::{SamplingOptionsProvider, StopConditionsProvider},
    openai::{
        chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
        completions::{CompletionRequest, CompletionResponse},
        nvext::NvExtProvider,
        DeltaGeneratorExt,
    },
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};

use crate::preprocessor::prompt::PromptFormatter;

pub use crate::protocols::common::llm_backend::{BackendInput, BackendOutput};

pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";

pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
}

impl OpenAIPreprocessor {
    pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
        let PromptFormatter::OAI(formatter) = formatter;

        let tokenizer = match &mdc.tokenizer {
            TokenizerKind::HfTokenizerJson(file) => HuggingFaceTokenizer::from_file(&file)?,
        };
        let tokenizer = Arc::new(tokenizer);

        let model_info = mdc.model_info.get_model_info().await?;

        let mdcsum = mdc.mdcsum();

        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
        }))
    }

    /// Translate a [`ChatCompletionRequest`] request to a common completion request.
    /// Returns both the common completion request and a hashmap of annotations.
    ///
    /// Annotations evaluated by this method include:
    /// - `formatted_prompt`
    /// - `token_ids`
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<(BackendInput, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = BackendInput::builder();

        let use_raw_prompt = request
            .nvext()
            .map_or(false, |ext| ext.use_raw_prompt.unwrap_or(false));

        let formatted_prompt = if use_raw_prompt {
            match request.raw_prompt() {
                Some(prompt) => prompt,
                None => {
                    tracing::warn!("Raw prompt requested but not available");
                    self.formatter.render(request)?
                }
            }
        } else {
            self.formatter.render(request)?
        };

        let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?;

        if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
            annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt);
        }

        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&encoding.token_ids)?,
            );
        }

        let mut stop_conditions = request.extract_stop_conditions()?;

        // todo - pull this from the mdc default sampling/stop params
        if stop_conditions.max_tokens.is_none() {
            stop_conditions.max_tokens = Some(64);
        }

        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        // apply ignore eos if not already set
        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.token_ids(encoding.token_ids);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.stop_conditions(stop_conditions);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
        };

        // transform the common response stream into a chat response stream
        let stream = stream::unfold(state, |mut inner| {
            async move {
                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

                    let response = response.map_data(|data| {
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "OpenAI ChatCompletionResponseDelta: {:?}",
                        response
                    );

                    Some((response, inner))
                } else {
                    // stream closed with out graceful closure
                    // we did not detect an is_finished/completed message
                    // Ok(None)
                    None
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
}

// for pals, we do not want to add the generation prompt to the formatted prompt
// we also need to know if the template support this add_generation_prompt bool
// any prompt template that does not support this should return an error
// oob - we should update any prompt template that does not support this to support it
#[async_trait]
impl
    Operator<
        SingleIn<ChatCompletionRequest>,
        ManyOut<Annotated<ChatCompletionResponseDelta>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<ChatCompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as i32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations this will be prepend to the response stream
        let annotations: Vec<Annotated<ChatCompletionResponseDelta>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<CompletionRequest>,
        ManyOut<Annotated<CompletionResponse>>,
        SingleIn<BackendInput>,
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<CompletionRequest>,
        next: Arc<
            dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
        >,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
        let response_generator = request.response_generator();
        let mut response_generator = Box::new(response_generator);

        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
        response_generator.update_isl(common_request.token_ids.len() as i32);

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations this will be prepend to the response stream
        let annotations: Vec<Annotated<CompletionResponse>> = annotations
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}
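A rough, standalone usage sketch of the preprocessor path (render the chat template, tokenize, build the BackendInput). The external module paths and the request value are assumptions; only the method names come from the code above:

use triton_llm::model_card::model::ModelDeploymentCard;
use triton_llm::preprocessor::OpenAIPreprocessor;
use triton_llm::protocols::openai::chat_completions::ChatCompletionRequest;

async fn tokenize_request(request: &ChatCompletionRequest) -> anyhow::Result<()> {
    // Build a preprocessor from a model deployment card (path is hypothetical).
    let mdc = ModelDeploymentCard::from_local_path("/models/llama", None).await?;
    let preprocessor = OpenAIPreprocessor::new(mdc).await?;

    // Render the chat template and tokenize; annotations carry the formatted
    // prompt / token ids back to the caller when requested.
    let (backend_input, annotations) = preprocessor.preprocess_request(request)?;
    println!("prompt tokens: {}", backend_input.token_ids.len());
    println!("annotations: {annotations:?}");
    Ok(())
}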
llm/rust/triton-llm/src/preprocessor/prompt.rs (new file, 0 → 100644):

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Prompt Formatting Module
//!
//! Handles formatting of LLM request prompts, including:
//! - Chat template rendering
//! - Tool usage formatting
//! - Generation prompt handling
//!
//! The module supports different prompt formatting strategies through the
//! PromptFormatter

// TODO:
// 1. Query if `add_generation_prompt` is present in the prompt template
// 2. Support for models with add_generation_prompt:
//    - PALS (Prefix-Assisted Language Sampling)
//    - Continuation - Detected on user turns, where we can return
//      partial assistant responses without add_generation_prompt

use anyhow::Result;
use minijinja::value::Value;
use std::sync::Arc;

mod template;

pub use template::ContextMixins;

/// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest {
    fn messages(&self) -> Value;

    fn tools(&self) -> Option<Value> {
        None
    }

    fn tool_choice(&self) -> Option<Value> {
        None
    }

    fn should_add_generation_prompt(&self) -> bool;
}

pub trait OAIPromptFormatter: Send + Sync + 'static {
    fn supports_add_generation_prompt(&self) -> bool;
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}

pub enum PromptFormatter {
    OAI(Arc<dyn OAIPromptFormatter>),
}
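Any request type can be rendered through the formatter as long as it exposes its messages (and, optionally, tools) as minijinja values. A hypothetical implementation of `OAIChatLikeRequest` for a minimal struct might look like this (the struct and its fields are illustrative, not part of the diff):

use anyhow::Result;
use minijinja::value::Value;
use serde::Serialize;

// Hypothetical request type carrying only what the formatter needs.
#[derive(Serialize)]
struct SimpleMessage {
    role: String,
    content: String,
}

struct SimpleChatRequest {
    messages: Vec<SimpleMessage>,
}

impl OAIChatLikeRequest for SimpleChatRequest {
    fn messages(&self) -> Value {
        // minijinja can build a Value from any serde-serializable type.
        Value::from_serialize(&self.messages)
    }

    fn should_add_generation_prompt(&self) -> bool {
        // Standard chat completion: end the prompt with the assistant header.
        true
    }
}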
llm/rust/triton-llm/src/preprocessor/prompt/template.rs
0 → 100644
View file @
4f6f63cd
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::HashSet, sync::Arc};

use anyhow::{Ok, Result};
use minijinja::Environment;

use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};

mod context;
mod formatters;
mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{raise_exception, tojson, ChatTemplate as HfTokenizerConfig};

impl PromptFormatter {
    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
        match mdc
            .prompt_formatter
            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
        {
            PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                let content = std::fs::read_to_string(file)?;
                let config: HfTokenizerConfig = serde_json::from_str(&content)?;
                let formatter = HfTokenizerConfigJsonFormatter::new(
                    config,
                    mdc.prompt_context
                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
                )?;

                Ok(Self::OAI(Arc::new(formatter)))
            }
        }
    }
}

/// Chat Template Jinja Renderer
///
/// Manages a Jinja environment with registered templates for chat formatting.
/// Handles two types of ChatTemplateValue templates:
///
/// 1. String template: Registered as the 'default' template
/// 2. Map template: Contains 'tool_use' and/or 'default' templates
///    - tool_use: Template for tool-based interactions
///    - default: Template for standard chat interactions
///
/// If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
/// and the `default` template is registered as the `default` template.
struct JinjaEnvironment {
    env: Environment<'static>,
}

/// Formatter for HuggingFace tokenizer config JSON templates
///
/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
/// Supports:
/// - Tool usage templates
/// - Generation prompts
/// - Context mixins for template customization
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
    env: Environment<'static>,
    config: HfTokenizerConfig,
    mixins: Arc<ContextMixins>,
    supports_add_generation_prompt: bool,
}

// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }

#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
    context_mixins: HashSet<PromptContextMixin>,
}
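For reference, the two `chat_template` shapes that `from_mdc` ultimately accepts look like this. The JSON snippets are illustrative, not taken from a real model, and deserialize into the `ChatTemplate` type aliased above as `HfTokenizerConfig`:

// Sketch: both accepted tokenizer_config.json shapes parse into HfTokenizerConfig.
fn chat_template_shapes_sketch() -> Result<()> {
    // 1. Single string template: registered as both `default` and `tool_use`.
    let single = r#"{ "chat_template": "{% for m in messages %}{{ m.content }}{% endfor %}" }"#;

    // 2. Map form: explicit `default` and `tool_use` entries.
    let map = r#"{ "chat_template": [
        { "default":  "{% for m in messages %}{{ m.content }}{% endfor %}" },
        { "tool_use": "{{ tools | tojson }}{% for m in messages %}{{ m.content }}{% endfor %}" }
    ] }"#;

    let _single: HfTokenizerConfig = serde_json::from_str(single)?;
    let _map: HfTokenizerConfig = serde_json::from_str(map)?;
    Ok(())
}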
llm/rust/triton-llm/src/preprocessor/prompt/template/context.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::{ContextMixins, PromptContextMixin};
use chrono::{DateTime, Utc};
use minijinja::value::{Object, Value};
use std::sync::Arc;

impl Object for ContextMixins {
    fn get_value(self: &Arc<Self>, field: &Value) -> Option<Value> {
        match field.as_str()? {
            "datetime" => self.datetime(),
            _ => None,
        }
    }
}

impl ContextMixins {
    pub fn new(allowed_mixins: &[PromptContextMixin]) -> Self {
        ContextMixins {
            context_mixins: allowed_mixins.iter().cloned().collect(),
        }
    }

    /// Implements the `datetime` context mixin.
    /// Different mixins can be implemented here for the same key.
    /// We need to validate that multiple mixins do not conflict with each other.
    fn datetime(&self) -> Option<Value> {
        if self
            .context_mixins
            .contains(&PromptContextMixin::Llama3DateTime)
        {
            let now = chrono::Utc::now();
            Some(Value::from(llama3_datetime(now)))
        } else {
            None
        }
    }
}

fn llama3_datetime(datetime: DateTime<Utc>) -> String {
    datetime.format("%d, %B, %Y").to_string()
}
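A quick check of the mixin's date format (a sketch, not part of the original file): `%d, %B, %Y` renders 24 Feb 2025 as "24, February, 2025".

#[cfg(test)]
mod llama3_datetime_sketch {
    use super::llama3_datetime;
    use chrono::{TimeZone, Utc};

    #[test]
    fn formats_day_month_year() {
        let dt = Utc.with_ymd_and_hms(2025, 2, 24, 12, 0, 0).unwrap();
        assert_eq!(llama3_datetime(dt), "24, February, 2025");
    }
}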
llm/rust/triton-llm/src/preprocessor/prompt/template/formatters.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;
use either::Either;
use tracing;

impl JinjaEnvironment {
    fn env(self) -> Environment<'static> {
        self.env
    }
}

impl Default for JinjaEnvironment {
    fn default() -> Self {
        let mut env = Environment::new();
        env.set_lstrip_blocks(true);
        env.set_trim_blocks(true);
        JinjaEnvironment { env }
    }
}

impl HfTokenizerConfigJsonFormatter {
    pub fn new(config: HfTokenizerConfig, mixins: ContextMixins) -> Result<Self> {
        let mut env = JinjaEnvironment::default().env();

        let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
            "chat_template field is required in the tokenizer_config.json file"
        ))?;

        // add pycompat
        // todo: should we use this: minijinja_contrib::add_to_environment(&mut env);
        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);

        // add custom functions and filters
        env.add_function("raise_exception", raise_exception);
        env.add_filter("tojson", tojson);

        let mut supports_add_generation_prompt = None;

        match &chat_template.0 {
            Either::Left(x) => {
                if x.contains("add_generation_prompt") {
                    tracing::debug!(
                        "Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
                    );
                    supports_add_generation_prompt = Some(true);
                }
                env.add_template_owned("default", x.to_string())?;
                env.add_template_owned("tool_use", x.to_string())?;
            }
            Either::Right(map) => {
                for t in map {
                    for (k, v) in t.iter() {
                        if v.contains("add_generation_prompt") {
                            match supports_add_generation_prompt {
                                Some(true) | None => {
                                    tracing::debug!(
                                        "Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
                                    );
                                    supports_add_generation_prompt = Some(true);
                                }
                                Some(false) => {
                                    tracing::warn!(
                                        "Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt."
                                    );
                                }
                            }
                        } else {
                            supports_add_generation_prompt = Some(false);
                        }
                        env.add_template_owned(k.to_string(), v.to_string())?;
                    }
                }
                if env.templates().count() == 0 {
                    anyhow::bail!(
                        "Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."
                    );
                }
            }
        }

        Ok(HfTokenizerConfigJsonFormatter {
            env,
            config,
            mixins: Arc::new(mixins),
            supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
        })
    }
}

// impl JinjaEnvironment {
//     /// Renders the template with the provided messages.
//     /// This function reuses the pre-compiled template for efficiency.
//     pub fn render(&self, template_id: &str, ctx: &dyn erased_serde::Serialize) -> Result<String> {
//         let tmpl = self.env.get_template(template_id)?;
//         Ok(tmpl.render(ctx)?)
//     }

//     // fn apply_tool_template()
// }
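Standalone sketch of the minijinja wiring `new` performs: register an owned template on an Environment with the same block settings and render it. Illustrative only; the real constructor also wires pycompat, `raise_exception`, and the `tojson` filter, and the template string here is made up:

// Sketch: minimal template registration and rendering with minijinja.
fn minijinja_wiring_sketch() -> Result<()> {
    let mut env = Environment::new();
    env.set_lstrip_blocks(true);
    env.set_trim_blocks(true);
    env.add_template_owned(
        "default".to_string(),
        "{% for m in messages %}{{ m.role }}: {{ m.content }}{% endfor %}".to_string(),
    )?;

    let tmpl = env.get_template("default")?;
    let out = tmpl.render(minijinja::context! {
        messages => vec![minijinja::context! { role => "user", content => "hi" }],
    })?;
    println!("{out}"); // prints "user: hi"
    Ok(())
}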
llm/rust/triton-llm/src/preprocessor/prompt/template/oai.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;
use minijinja::{context, value::Value};

use crate::protocols::openai::{
    chat_completions::{ChatCompletionMessage, ChatCompletionRequest, Content, MessageRole},
    completions::CompletionRequest,
};

use tracing;

impl OAIChatLikeRequest for ChatCompletionRequest {
    fn messages(&self) -> Value {
        Value::from_serialize(&self.messages)
    }

    fn tools(&self) -> Option<Value> {
        if self.tools.is_none() {
            None
        } else {
            Some(Value::from_serialize(&self.tools))
        }
    }

    fn tool_choice(&self) -> Option<Value> {
        if self.tool_choice.is_none() {
            None
        } else {
            Some(Value::from_serialize(&self.tool_choice))
        }
    }

    fn should_add_generation_prompt(&self) -> bool {
        if let Some(last) = self.messages.last() {
            last.role == MessageRole::user
        } else {
            true
        }
    }
}

impl OAIChatLikeRequest for CompletionRequest {
    fn messages(&self) -> Value {
        let message = ChatCompletionMessage {
            role: MessageRole::user,
            content: Content::Text(self.prompt.clone()),
            name: None,
        };

        Value::from_serialize(vec![message])
    }

    fn should_add_generation_prompt(&self) -> bool {
        true
    }
}

impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        let has_tools = tools.is_some();
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            has_tools,
            add_generation_prompt
        );

        let ctx = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        let ctx = context! { ..ctx, ..context! {} };

        let tmpl = if has_tools {
            self.env.get_template("tool_use")?
        } else {
            self.env.get_template("default")?
        };

        Ok(tmpl.render(&ctx)?)
    }
}
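Usage-wise, these impls mean a plain text completion and a chat request flow through the same renderer. A tiny hedged sketch (the helper function is hypothetical, not part of this commit):

/// Hypothetical helper: render a text completion through the chat-template path.
/// CompletionRequest::messages() wraps the prompt as a single user message, and
/// should_add_generation_prompt() is always true for completions.
fn render_completion(formatter: &dyn OAIPromptFormatter, req: &CompletionRequest) -> Result<String> {
    formatter.render(req)
}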
llm/rust/triton-llm/src/preprocessor/prompt/template/tokcfg.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// based on: https://github.com/EricLBuehler/mistral.rs/blob/d970bb5feb863acf8e8ec90de97e18221fb959f1/mistralrs-core/src/pipeline/chat_template.rs

use std::collections::HashMap;

use either::Either;
use minijinja::{value::Kwargs, Error, ErrorKind, Value};
use serde::{Deserialize, Serialize};

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}

pub fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
}

#[derive(Debug, Deserialize)]
pub struct BeginEndUnkTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);

/// Support older tool use patterns where the tool use template was separate from the default/chat template.
/// Modern patterns use a single template with a `tool_use` key, e.g.
///
/// ```jinja
/// {%- if tools is not none and tool_choice is not none %}
/// ```
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);

/// If present, pad_token is usually a single value. Deepseek R1 and its distills use a map.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct PadTokenValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);

#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template.
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    pub bos_token: Option<BeginEndUnkTok>,

    /// Jinja format [chat templating] for chat completion.
    ///
    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    pub eos_token: Option<BeginEndUnkTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    pad_token: Option<PadTokenValue>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    pub unk_token: Option<BeginEndUnkTok>,
    use_default_system_prompt: Option<bool>,
}

impl ChatTemplate {
    // pub fn has_chat_template(&self) -> bool {
    //     self.chat_template.is_some()
    // }

    pub fn eos_tok(&self) -> Option<String> {
        match self.eos_token.as_ref()?.0 {
            Either::Left(ref lit) => Some(lit.clone()),
            Either::Right(ref added) => Some(added.content.clone()),
        }
    }

    pub fn bos_tok(&self) -> Option<String> {
        match self.bos_token.as_ref()?.0 {
            Either::Left(ref lit) => Some(lit.clone()),
            Either::Right(ref added) => Some(added.content.clone()),
        }
    }

    pub fn unk_tok(&self) -> Option<String> {
        match self.unk_token.as_ref()?.0 {
            Either::Left(ref lit) => Some(lit.clone()),
            Either::Right(ref added) => Some(added.content.clone()),
        }
    }
}

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    #[serde(with = "either::serde_untagged")]
    bos_token_id: Either<u32, Vec<u32>>,
    #[serde(with = "either::serde_untagged")]
    eos_token_id: Either<u32, Vec<u32>>,
}

pub fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
    if let Ok(indent) = kwargs.get("indent") {
        let mut buf = Vec::new();
        let repeat = b" ".repeat(indent);
        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
        value.serialize(&mut ser).unwrap();
        String::from_utf8(buf).map_err(|err| {
            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
        })
    } else {
        serde_json::to_string(&value).map_err(|err| {
            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
        })
    }
    .map_err(|err| {
        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
    })
    .map(|s| {
        // When this filter is used the return value is safe for both HTML and JSON
        let mut rv = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                '<' => rv.push_str("\\u003c"),
                '>' => rv.push_str("\\u003e"),
                '&' => rv.push_str("\\u0026"),
                '\'' => rv.push_str("\\u0027"),
                _ => rv.push(c),
            }
        }
        Value::from_safe_string(rv)
    })
}
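The untagged `Either` wrappers above mean special tokens can arrive either as plain strings or as added-token objects. A sketch of both shapes (the JSON values are illustrative, not from a real config):

#[cfg(test)]
mod token_shape_sketch {
    use super::ChatTemplate;

    #[test]
    fn eos_token_as_string_or_map() {
        // plain string form
        let simple: ChatTemplate = serde_json::from_str(r#"{ "eos_token": "</s>" }"#).unwrap();
        assert_eq!(simple.eos_tok(), Some("</s>".to_string()));

        // added-token map form
        let added: ChatTemplate = serde_json::from_str(
            r#"{ "eos_token": { "content": "<|eot_id|>", "lstrip": false,
                 "normalized": false, "rstrip": false, "single_word": false } }"#,
        )
        .unwrap();
        assert_eq!(added.eos_tok(), Some("<|eot_id|>".to_string()));
    }
}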
llm/rust/triton-llm/src/preprocessor/tools.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod request;
mod response;

pub use request::*;
pub use response::*;

use serde_json::Value;
use std::collections::HashMap;
use uuid::Uuid;

/// Matches and processes tool calling patterns in LLM responses
///
/// Supports multiple formats for tool calls:
/// - Single/multiple function calls with parameters/arguments
/// - Auto or user selected tool usage
pub struct ToolCallingMatcher {
    tool_choice: ToolChoice,
}

// Same as CalledFunction with named parameters
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionParameters {
    pub name: String,
    pub parameters: HashMap<String, Value>,
}

// Same as CalledFunction with named arguments
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunctionArguments {
    pub name: String,
    pub arguments: HashMap<String, Value>,
}

impl ToolCallingMatcher {
    pub fn new(tool_choice: ToolChoice) -> anyhow::Result<Self> {
        Ok(Self { tool_choice })
    }

    pub fn get_call(&self, message: &str) -> anyhow::Result<Vec<ToolCallResponse>> {
        if matches!(self.tool_choice, ToolChoice::None) {
            return Ok(Vec::new());
        }

        if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(message) {
            let id = format!("call-{}", Uuid::new_v4());
            Ok(vec![ToolCallResponse {
                id,
                tp: ToolCallType::Function,
                function: CalledFunction {
                    name: deser.name,
                    arguments: serde_json::to_string(&deser.parameters)?,
                },
            }])
        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(message) {
            Ok(deser
                .into_iter()
                .map(|deser| {
                    let id = format!("call-{}", Uuid::new_v4());
                    Ok(ToolCallResponse {
                        id,
                        tp: ToolCallType::Function,
                        function: CalledFunction {
                            name: deser.name,
                            arguments: serde_json::to_string(&deser.parameters)?,
                        },
                    })
                })
                .collect::<anyhow::Result<Vec<_>>>()?)
        } else if let Ok(deser) = serde_json::from_str::<CalledFunctionArguments>(message) {
            let id = format!("call-{}", Uuid::new_v4());
            Ok(vec![ToolCallResponse {
                id,
                tp: ToolCallType::Function,
                function: CalledFunction {
                    name: deser.name,
                    arguments: serde_json::to_string(&deser.arguments)?,
                },
            }])
        } else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionArguments>>(message) {
            Ok(deser
                .into_iter()
                .map(|deser| {
                    let id = format!("call-{}", Uuid::new_v4());
                    Ok(ToolCallResponse {
                        id,
                        tp: ToolCallType::Function,
                        function: CalledFunction {
                            name: deser.name,
                            arguments: serde_json::to_string(&deser.arguments)?,
                        },
                    })
                })
                .collect::<anyhow::Result<Vec<_>>>()?)
        } else {
            if matches!(self.tool_choice, ToolChoice::Tool(_)) {
                anyhow::bail!("Tool choice was required but no tools were called.")
            }
            Ok(Vec::new())
        }
    }
}
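A sketch of the matcher against the JSON shapes it accepts ("parameters" vs "arguments", single vs batched). It assumes `ToolChoice::Auto` so parsing is attempted, and only checks call counts since the response types live in the sibling `response` module:

#[cfg(test)]
mod matcher_sketch {
    use super::*;

    #[test]
    fn parses_single_and_batched_calls() -> anyhow::Result<()> {
        let matcher = ToolCallingMatcher::new(ToolChoice::Auto)?;

        // single call using the "arguments" spelling
        let single =
            matcher.get_call(r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#)?;
        assert_eq!(single.len(), 1);

        // batched calls using the "parameters" spelling
        let batched = matcher
            .get_call(r#"[{"name": "a", "parameters": {}}, {"name": "b", "parameters": {}}]"#)?;
        assert_eq!(batched.len(), 2);

        // plain text is not a tool call
        assert!(matcher.get_call("just a normal sentence")?.is_empty());
        Ok(())
    }
}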
llm/rust/triton-llm/src/preprocessor/tools/request.rs
0 → 100644
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use serde_json::Value;

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub enum ToolType {
    #[serde(rename = "function")]
    Function,
}

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub enum ToolChoice {
    #[serde(rename = "none")]
    /// Disallow selection of tools.
    None,
    #[serde(rename = "auto")]
    /// Allow automatic selection of any given tool, or none.
    Auto,
    #[serde(untagged)]
    /// Force selection of a given tool.
    Tool(Tool),
}

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct Function {
    pub description: Option<String>,
    pub name: String,
    pub parameters: Option<HashMap<String, Value>>,
}

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tp: ToolType,
    pub function: Function,
}
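For reference, the wire formats `ToolChoice` accepts: the "none"/"auto" strings plus the untagged forced-tool object. A small sketch (the JSON values are illustrative):

#[cfg(test)]
mod tool_choice_sketch {
    use super::ToolChoice;

    #[test]
    fn deserializes_all_variants() {
        let none: ToolChoice = serde_json::from_str(r#""none""#).unwrap();
        assert!(matches!(none, ToolChoice::None));

        let auto: ToolChoice = serde_json::from_str(r#""auto""#).unwrap();
        assert!(matches!(auto, ToolChoice::Auto));

        let forced: ToolChoice = serde_json::from_str(
            r#"{ "type": "function", "function": { "name": "get_weather" } }"#,
        )
        .unwrap();
        assert!(matches!(forced, ToolChoice::Tool(_)));
    }
}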