Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
86aff237
Commit
86aff237
authored
Feb 26, 2025
by
Paul Hendricks
Committed by
GitHub
Feb 26, 2025
Browse files
refactor: using async_openai
Co-authored-by:
Graham King
<
grahamk@nvidia.com
>
parent
d694ca6e
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2079 additions
and
1873 deletions
+2079
-1873
applications/llm/count/Cargo.lock
applications/llm/count/Cargo.lock
+325
-0
examples/rust/Cargo.lock
examples/rust/Cargo.lock
+325
-0
launch/tio/Cargo.lock
launch/tio/Cargo.lock
+194
-0
launch/tio/Cargo.toml
launch/tio/Cargo.toml
+1
-0
launch/tio/src/input/text.rs
launch/tio/src/input/text.rs
+42
-25
launch/tio/src/output/echo_full.rs
launch/tio/src/output/echo_full.rs
+25
-11
lib/bindings/c/Cargo.lock
lib/bindings/c/Cargo.lock
+325
-0
lib/bindings/python/Cargo.lock
lib/bindings/python/Cargo.lock
+325
-0
lib/llm/Cargo.lock
lib/llm/Cargo.lock
+196
-106
lib/llm/Cargo.toml
lib/llm/Cargo.toml
+1
-0
lib/llm/src/engines/mistralrs.rs
lib/llm/src/engines/mistralrs.rs
+28
-17
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+8
-4
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+1
-1
lib/llm/src/preprocessor/prompt/template/oai.rs
lib/llm/src/preprocessor/prompt/template/oai.rs
+27
-16
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+27
-0
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+5
-4
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+35
-1553
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+129
-96
lib/llm/src/protocols/openai/chat_completions/delta.rs
lib/llm/src/protocols/openai/chat_completions/delta.rs
+54
-34
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+6
-6
No files found.
applications/llm/count/Cargo.lock
View file @
86aff237
...
...
@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -995,6 +1034,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
...
...
@@ -1157,6 +1207,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
...
...
@@ -1202,8 +1258,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
...
...
@@ -1351,6 +1409,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
...
...
@@ -1596,6 +1672,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
...
...
@@ -1605,6 +1690,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
...
...
@@ -1824,6 +1915,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -2527,6 +2628,58 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
]
[[package]]
name = "quote"
version = "1.0.38"
...
...
@@ -2661,6 +2814,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
...
...
@@ -2681,6 +2896,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
...
...
@@ -2757,6 +2978,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
...
...
@@ -2805,6 +3029,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -3123,6 +3357,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
...
...
@@ -3261,6 +3498,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
...
...
@@ -3612,6 +3864,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
...
...
@@ -3740,6 +3993,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
...
...
@@ -3957,6 +4216,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
...
...
@@ -3989,6 +4261,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
...
...
@@ -4060,6 +4355,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
...
...
examples/rust/Cargo.lock
View file @
86aff237
...
...
@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -987,6 +1026,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
...
...
@@ -1149,6 +1199,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
...
...
@@ -1194,8 +1250,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
...
...
@@ -1362,6 +1420,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http 1.2.0",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
...
...
@@ -1607,6 +1683,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
...
...
@@ -1616,6 +1701,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
...
...
@@ -1849,6 +1940,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -2563,6 +2664,58 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
...
...
@@ -2697,6 +2850,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http 1.2.0",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
...
...
@@ -2717,6 +2932,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
...
...
@@ -2793,6 +3014,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
...
...
@@ -2841,6 +3065,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -3170,6 +3404,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
...
...
@@ -3331,6 +3568,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
...
...
@@ -3682,6 +3934,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
...
...
@@ -3810,6 +4063,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
...
...
@@ -4027,6 +4286,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
...
...
@@ -4059,6 +4331,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
...
...
@@ -4130,6 +4425,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
...
...
launch/tio/Cargo.lock
View file @
86aff237
...
...
@@ -202,6 +202,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -371,6 +396,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -1410,6 +1449,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.73.0"
...
...
@@ -1647,6 +1697,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
...
...
@@ -1810,8 +1866,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
...
...
@@ -2022,6 +2080,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
...
...
@@ -2658,6 +2717,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -3731,6 +3800,58 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
...
...
@@ -3966,11 +4087,16 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
...
...
@@ -3978,15 +4104,34 @@ dependencies = [
"system-configuration",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
...
...
@@ -4107,6 +4252,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
...
...
@@ -4189,6 +4337,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -4843,11 +5001,27 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tio"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"clap",
...
...
@@ -5276,6 +5450,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"async_zmq",
...
...
@@ -5406,6 +5581,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
...
...
@@ -5709,6 +5890,19 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
...
...
launch/tio/Cargo.toml
View file @
86aff237
...
...
@@ -29,6 +29,7 @@ metal = ["triton-distributed-llm/metal"]
[dependencies]
anyhow
=
"1"
async-openai
=
"0.27.2"
async-stream
=
{
version
=
"0.3"
}
async-trait
=
{
version
=
"0.1"
}
clap
=
{
version
=
"4.5"
,
features
=
[
"derive"
,
"env"
]
}
...
...
launch/tio/src/input/text.rs
View file @
86aff237
...
...
@@ -21,7 +21,6 @@ use std::{
use
triton_distributed_llm
::{
backend
::
Backend
,
preprocessor
::
OpenAIPreprocessor
,
protocols
::
openai
::
chat_completions
::
MessageRole
,
types
::{
openai
::
chat_completions
::{
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
...
...
@@ -38,7 +37,7 @@ use triton_distributed_runtime::{
use
crate
::
EngineConfig
;
/// Max response tokens for each single query. Must be less than model context size.
const
MAX_TOKENS
:
i
32
=
8192
;
const
MAX_TOKENS
:
u
32
=
8192
;
/// Output of `isatty` if the fd is indeed a TTY
const
IS_A_TTY
:
i32
=
1
;
...
...
@@ -96,11 +95,12 @@ pub async fn run(
main_loop
(
cancel_token
,
&
service_name
,
engine
,
inspect_template
)
.await
}
#[allow(deprecated)]
async
fn
main_loop
(
cancel_token
:
CancellationToken
,
service_name
:
&
str
,
engine
:
OpenAIChatCompletionsStreamingEngine
,
inspect_template
:
bool
,
_
inspect_template
:
bool
,
)
->
anyhow
::
Result
<
()
>
{
tracing
::
info!
(
"Ctrl-c to exit"
);
let
theme
=
dialoguer
::
theme
::
ColorfulTheme
::
default
();
...
...
@@ -141,30 +141,31 @@ async fn main_loop(
}
}
};
messages
.push
((
MessageRole
::
user
,
prompt
.clone
()));
// Construct messages
let
user_message
=
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
async_openai
::
types
::
ChatCompletionRequestUserMessage
{
content
:
async_openai
::
types
::
ChatCompletionRequestUserMessageContent
::
Text
(
prompt
),
name
:
None
,
},
);
messages
.push
(
user_message
);
// Request
let
mut
req_builder
=
ChatCompletionRequest
::
builder
()
;
req_builder
let
inner
=
async_openai
::
types
::
Create
ChatCompletionRequest
Args
::
default
()
.messages
(
messages
.clone
())
.model
(
service_name
)
.stream
(
true
)
.max_tokens
(
MAX_TOKENS
);
if
inspect_template
{
// This makes the pre-processor ignore stop tokens
req_builder
.min_tokens
(
8192
);
}
for
(
role
,
msg
)
in
&
messages
{
match
role
{
MessageRole
::
user
=>
{
req_builder
.add_user_message
(
msg
);
}
MessageRole
::
assistant
=>
{
req_builder
.add_assistant_message
(
msg
);
}
x
=>
panic!
(
"Only 'user' and 'assistant' messages are supported, not {x}"
),
}
}
let
req
=
req_builder
.build
()
?
;
.max_tokens
(
MAX_TOKENS
)
.build
()
?
;
// TODO We cannot set min_tokens with async-openai
// if inspect_template {
// // This makes the pre-processor ignore stop tokens
// req_builder.min_tokens(8192);
// }
let
req
=
ChatCompletionRequest
{
inner
,
nvext
:
None
};
// Call the model
let
mut
stream
=
engine
.generate
(
Context
::
new
(
req
))
.await
?
;
...
...
@@ -174,7 +175,7 @@ async fn main_loop(
let
mut
assistant_message
=
String
::
new
();
while
let
Some
(
item
)
=
stream
.next
()
.await
{
let
data
=
item
.data
.as_ref
()
.unwrap
();
let
entry
=
data
.choices
.first
();
let
entry
=
data
.
inner.
choices
.first
();
let
chat_comp
=
entry
.as_ref
()
.unwrap
();
if
let
Some
(
c
)
=
&
chat_comp
.delta.content
{
let
_
=
stdout
.write
(
c
.as_bytes
());
...
...
@@ -188,7 +189,23 @@ async fn main_loop(
}
println!
();
messages
.push
((
MessageRole
::
assistant
,
assistant_message
));
let
assistant_content
=
async_openai
::
types
::
ChatCompletionRequestAssistantMessageContent
::
Text
(
assistant_message
,
);
// ALLOW: function_call is deprecated
let
assistant_message
=
async_openai
::
types
::
ChatCompletionRequestMessage
::
Assistant
(
async_openai
::
types
::
ChatCompletionRequestAssistantMessage
{
content
:
Some
(
assistant_content
),
refusal
:
None
,
name
:
None
,
audio
:
None
,
tool_calls
:
None
,
function_call
:
None
,
},
);
messages
.push
(
assistant_message
);
}
println!
();
Ok
(())
...
...
launch/tio/src/output/echo_full.rs
View file @
86aff237
...
...
@@ -18,9 +18,8 @@ use std::{sync::Arc, time::Duration};
use
async_stream
::
stream
;
use
async_trait
::
async_trait
;
use
triton_distributed_llm
::
protocols
::
openai
::
chat_completions
::
FinishReason
;
use
triton_distributed_llm
::
protocols
::
openai
::
chat_completions
::{
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
Content
,
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
};
use
triton_distributed_llm
::
types
::
openai
::
chat_completions
::
OpenAIChatCompletionsStreamingEngine
;
use
triton_distributed_runtime
::
engine
::{
AsyncEngine
,
AsyncEngineContextProvider
,
ResponseStream
};
...
...
@@ -53,25 +52,40 @@ impl
let
(
request
,
context
)
=
incoming_request
.transfer
(());
let
deltas
=
request
.response_generator
();
let
ctx
=
context
.context
();
let
req
=
request
.messages
.into_iter
()
.last
()
.unwrap
();
let
prompt
=
match
req
.content
{
Content
::
Text
(
prompt
)
=>
prompt
,
_
=>
{
anyhow
::
bail!
(
"Invalid request content field, expected Content::Text"
);
let
req
=
request
.inner.messages
.into_iter
()
.last
()
.unwrap
();
let
prompt
=
match
req
{
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
user_msg
)
=>
{
match
user_msg
.content
{
async_openai
::
types
::
ChatCompletionRequestUserMessageContent
::
Text
(
prompt
)
=>
{
prompt
}
_
=>
anyhow
::
bail!
(
"Invalid request content field, expected Content::Text"
),
}
}
_
=>
anyhow
::
bail!
(
"Invalid request type, expected User message"
),
};
let
output
=
stream!
{
let
mut
id
=
1
;
for
c
in
prompt
.chars
()
{
// we are returning characters not tokens, so speed up some
tokio
::
time
::
sleep
(
TOKEN_ECHO_DELAY
/
2
)
.await
;
let
delta
=
deltas
.create_choice
(
0
,
Some
(
c
.to_string
()),
None
,
None
);
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
delta
),
event
:
None
,
comment
:
None
};
let
inner
=
deltas
.create_choice
(
0
,
Some
(
c
.to_string
()),
None
,
None
);
let
response
=
ChatCompletionResponseDelta
{
inner
,
};
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
id
+=
1
;
}
let
stop_delta
=
deltas
.create_choice
(
0
,
None
,
Some
(
FinishReason
::
stop
),
None
);
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
stop_delta
),
event
:
None
,
comment
:
None
};
let
inner
=
deltas
.create_choice
(
0
,
None
,
Some
(
async_openai
::
types
::
FinishReason
::
Stop
),
None
);
let
response
=
ChatCompletionResponseDelta
{
inner
,
};
yield
Annotated
{
id
:
Some
(
id
.to_string
()),
data
:
Some
(
response
),
event
:
None
,
comment
:
None
};
};
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
output
),
ctx
))
}
}
lib/bindings/c/Cargo.lock
View file @
86aff237
...
...
@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -986,6 +1025,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
...
...
@@ -1148,6 +1198,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
...
...
@@ -1193,8 +1249,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
...
...
@@ -1348,6 +1406,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
...
...
@@ -1593,6 +1669,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
...
...
@@ -1602,6 +1687,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
...
...
@@ -1842,6 +1933,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -2545,6 +2646,58 @@ dependencies = [
"syn 2.0.96",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
...
...
@@ -2679,6 +2832,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
...
...
@@ -2700,6 +2915,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
...
...
@@ -2776,6 +2997,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
...
...
@@ -2824,6 +3048,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -3148,6 +3382,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
...
...
@@ -3286,6 +3523,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
...
...
@@ -3637,6 +3889,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
...
...
@@ -3765,6 +4018,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.14"
...
...
@@ -3982,6 +4241,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
...
...
@@ -4014,6 +4286,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
...
...
@@ -4085,6 +4380,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
...
...
lib/bindings/python/Cargo.lock
View file @
86aff237
...
...
@@ -163,6 +163,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -332,6 +357,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -1009,6 +1048,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
...
...
@@ -1171,6 +1221,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
...
...
@@ -1216,8 +1272,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
...
...
@@ -1365,6 +1423,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
...
...
@@ -1610,6 +1686,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "inventory"
version = "0.3.19"
...
...
@@ -1628,6 +1713,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
...
...
@@ -1847,6 +1938,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -2594,6 +2695,58 @@ dependencies = [
"serde",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
...
...
@@ -2728,6 +2881,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
...
...
@@ -2749,6 +2964,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
...
...
@@ -2825,6 +3046,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
...
...
@@ -2873,6 +3097,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -3197,6 +3431,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
...
...
@@ -3335,6 +3572,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
...
...
@@ -3686,6 +3938,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
...
...
@@ -3834,6 +4087,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.16"
...
...
@@ -4051,6 +4310,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
...
...
@@ -4083,6 +4355,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
...
...
@@ -4154,6 +4449,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
...
...
lib/llm/Cargo.lock
View file @
86aff237
...
...
@@ -129,9 +129,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.9
5
"
version = "1.0.9
6
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b0
4"
checksum = "
6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf
4"
[[package]]
name = "arbitrary"
...
...
@@ -202,6 +202,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
...
...
@@ -371,6 +396,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
...
...
@@ -460,15 +499,16 @@ checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "blake3"
version = "1.
5.5
"
version = "1.
6.0
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e
"
checksum = "
1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937
"
dependencies = [
"arrayref",
"arrayvec",
"cc",
"cfg-if 1.0.0",
"constant_time_eq",
"memmap2",
]
[[package]]
...
...
@@ -662,9 +702,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.1
2
"
version = "1.2.1
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
755717a7de9ec452bf7f3f1a3099085deabd7f2962b861dae91ecd7a365903d2
"
checksum = "
c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af
"
dependencies = [
"jobserver",
"libc",
...
...
@@ -730,9 +770,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.
29
"
version = "4.5.
31
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
8acebd8ad879283633b343856142139f2da2317c96b05b4dd6181c61e2480184
"
checksum = "
027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767
"
dependencies = [
"clap_builder",
"clap_derive",
...
...
@@ -740,9 +780,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.
29
"
version = "4.5.
31
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
f6ba32cbda51c7e1dfd49acc1457ba1a7dec5b64fe360e828acb13ca8dc9c2f9
"
checksum = "
5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863
"
dependencies = [
"anstream",
"anstyle",
...
...
@@ -990,9 +1030,9 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.13.
4
"
version = "0.13.
7
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
3b68d7c284d40d96a4251330ab583c2718b412f4fc53239d295b3a1f8735f426
"
checksum = "
4e29ce3bfa797c1183053ceb496316203ef561c183941c3c181500d9ade6daf4
"
dependencies = [
"half",
"libloading",
...
...
@@ -1096,9 +1136,9 @@ dependencies = [
[[package]]
name = "data-encoding"
version = "2.
7
.0"
version = "2.
8
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f
"
checksum = "
575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010
"
[[package]]
name = "defmac"
...
...
@@ -1320,9 +1360,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.1
3
.0"
version = "1.1
4
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0
"
checksum = "
b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d
"
dependencies = [
"serde",
]
...
...
@@ -1382,9 +1422,9 @@ dependencies = [
[[package]]
name = "equivalent"
version = "1.0.
1
"
version = "1.0.
2
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5
"
checksum = "
877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f
"
[[package]]
name = "erased-serde"
...
...
@@ -1422,7 +1462,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0452bcc559431b16f472b7ab86e2f9ccd5f3c2da3795afbd6b773665e047fe"
dependencies = [
"http",
"prost 0.13.
4
",
"prost 0.13.
5
",
"tokio",
"tokio-stream",
"tonic",
...
...
@@ -1431,6 +1471,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.73.0"
...
...
@@ -1486,15 +1537,15 @@ dependencies = [
[[package]]
name = "fixedbitset"
version = "0.
4.2
"
version = "0.
5.7
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80
"
checksum = "
1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99
"
[[package]]
name = "flate2"
version = "1.
0.35
"
version = "1.
1.0
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07
c"
checksum = "
11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38d
c"
dependencies = [
"crc32fast",
"miniz_oxide",
...
...
@@ -1892,9 +1943,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.4.
7
"
version = "0.4.
8
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e
"
checksum = "
5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2
"
dependencies = [
"atomic-waker",
"bytes",
...
...
@@ -2086,6 +2137,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
...
...
@@ -2551,9 +2603,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
version = "0.2.1
69
"
version = "0.2.1
70
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a
"
checksum = "
875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828
"
[[package]]
name = "libloading"
...
...
@@ -2642,9 +2694,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.2
5
"
version = "0.4.2
6
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15
e29f
ff3f6c077fd9cd9f
"
checksum = "
30bde2b3dc3671ae49d8
e2
e
9f
044c7c005836e7a023ee57cffa25ab82764bb9e
"
[[package]]
name = "lrtable"
...
...
@@ -2757,6 +2809,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
...
...
@@ -2787,9 +2849,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.
3
"
version = "0.8.
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
b8402cab7aefae129c6977bb0ff1b8fd9a04
eb
5
b5
1efc50a70bea51cda0c7924
"
checksum = "
8e3e04d
eb
b
b5
9698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5
"
dependencies = [
"adler2",
"simd-adler32",
...
...
@@ -2937,7 +2999,7 @@ dependencies = [
"tqdm",
"tracing",
"tracing-subscriber",
"uuid 1.1
3.1
",
"uuid 1.1
4.0
",
"variantly",
"vob",
]
...
...
@@ -3021,9 +3083,9 @@ checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "native-tls"
version = "0.2.1
3
"
version = "0.2.1
4
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9
fda36
5ade4fab353196c
"
checksum = "
87de3442987e9dbec73158d5c715e7ad9072
fda
9
36
bb03d19d7fa10e00520f0e
"
dependencies = [
"libc",
"log",
...
...
@@ -3289,9 +3351,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.20.
2
"
version = "1.20.
3
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775
"
checksum = "
945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e
"
[[package]]
name = "onig"
...
...
@@ -3317,9 +3379,9 @@ dependencies = [
[[package]]
name = "openssl"
version = "0.10.7
0
"
version = "0.10.7
1
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6
"
checksum = "
5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd
"
dependencies = [
"bitflags 2.8.0",
"cfg-if 1.0.0",
...
...
@@ -3349,9 +3411,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.10
5
"
version = "0.9.10
6
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b
22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc
"
checksum = "8b
b61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd
"
dependencies = [
"cc",
"libc",
...
...
@@ -3495,9 +3557,9 @@ dependencies = [
[[package]]
name = "petgraph"
version = "0.
6.5
"
version = "0.
7.1
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db
"
checksum = "
3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772
"
dependencies = [
"fixedbitset",
"indexmap 2.7.1",
...
...
@@ -3566,9 +3628,9 @@ dependencies = [
[[package]]
name = "portable-atomic"
version = "1.1
0
.0"
version = "1.1
1
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6
"
checksum = "
350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e
"
[[package]]
name = "powerfmt"
...
...
@@ -3695,28 +3757,28 @@ dependencies = [
[[package]]
name = "prost"
version = "0.13.
4
"
version = "0.13.
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2
c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec
"
checksum = "2
796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5
"
dependencies = [
"bytes",
"prost-derive 0.13.
4
",
"prost-derive 0.13.
5
",
]
[[package]]
name = "prost-build"
version = "0.13.
4
"
version = "0.13.
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b
"
checksum = "
be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf
"
dependencies = [
"heck 0.5.0",
"itertools 0.1
3
.0",
"itertools 0.1
4
.0",
"log",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
"prost 0.13.
4
",
"prost 0.13.
5
",
"prost-types",
"regex",
"syn 2.0.98",
...
...
@@ -3738,12 +3800,12 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.13.
4
"
version = "0.13.
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3
"
checksum = "
8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d
"
dependencies = [
"anyhow",
"itertools 0.1
3
.0",
"itertools 0.1
4
.0",
"proc-macro2",
"quote",
"syn 2.0.98",
...
...
@@ -3751,11 +3813,11 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.13.
4
"
version = "0.13.
5
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc
"
checksum = "
52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16
"
dependencies = [
"prost 0.13.
4
",
"prost 0.13.
5
",
]
[[package]]
...
...
@@ -3900,9 +3962,9 @@ dependencies = [
[[package]]
name = "quinn-udp"
version = "0.5.
9
"
version = "0.5.
10
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed90
4"
checksum = "
e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e594
4"
dependencies = [
"cfg_aliases",
"libc",
...
...
@@ -4037,9 +4099,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "redox_syscall"
version = "0.5.
8
"
version = "0.5.
9
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834
"
checksum = "
82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f
"
dependencies = [
"bitflags 2.8.0",
]
...
...
@@ -4162,12 +4224,14 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
...
...
@@ -4190,17 +4254,32 @@ dependencies = [
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.
8
"
version = "0.17.
11
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
c17fa4cb658e3583423e915b9f3acc01cceaee1
86
0
e3
3d59ebae66adc3a2dc0d
"
checksum = "
da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd
86e3
5683fe73
"
dependencies = [
"cc",
"cfg-if 1.0.0",
"getrandom 0.2.15",
"libc",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
...
...
@@ -4242,9 +4321,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.
0
"
version = "2.1.
1
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497
"
checksum = "
357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d
"
[[package]]
name = "rustc_version"
...
...
@@ -4270,9 +4349,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.2
2
"
version = "0.23.2
3
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7
"
checksum = "
47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395
"
dependencies = [
"log",
"once_cell",
...
...
@@ -4391,9 +4470,9 @@ dependencies = [
[[package]]
name = "schemars"
version = "0.8.2
1
"
version = "0.8.2
2
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92
"
checksum = "
3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615
"
dependencies = [
"dyn-clone",
"schemars_derive",
...
...
@@ -4403,9 +4482,9 @@ dependencies = [
[[package]]
name = "schemars_derive"
version = "0.8.2
1
"
version = "0.8.2
2
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e
"
checksum = "
32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d
"
dependencies = [
"proc-macro2",
"quote",
...
...
@@ -4419,6 +4498,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
...
...
@@ -4504,9 +4593,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
[[package]]
name = "serde"
version = "1.0.21
7
"
version = "1.0.21
8
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c7
0"
checksum = "
e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df6
0"
dependencies = [
"serde_derive",
]
...
...
@@ -4526,9 +4615,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.21
7
"
version = "1.0.21
8
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0
"
checksum = "
f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b
"
dependencies = [
"proc-macro2",
"quote",
...
...
@@ -4548,9 +4637,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.13
8
"
version = "1.0.13
9
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949
"
checksum = "
44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6
"
dependencies = [
"indexmap 2.7.1",
"itoa",
...
...
@@ -4733,9 +4822,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.1
3.2
"
version = "1.1
4.0
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67
"
checksum = "
7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd
"
[[package]]
name = "socket2"
...
...
@@ -4770,12 +4859,6 @@ dependencies = [
"vob",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.3"
...
...
@@ -5272,9 +5355,9 @@ dependencies = [
[[package]]
name = "toktrie"
version = "0.6.2
8
"
version = "0.6.2
9
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
5f9c32a81c3faff7dde7909b471a5c39970e684b7fb88994a0fe4d9a3fb3a2b1
"
checksum = "
1d6b6bcdb5d6345ffe9504e26906dd61c118c75191355f6219ada2854fe1421b
"
dependencies = [
"anyhow",
"bytemuck",
...
...
@@ -5299,16 +5382,16 @@ dependencies = [
[[package]]
name = "toktrie_hf_tokenizers"
version = "0.6.2
8
"
version = "0.6.2
9
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
9e43a051aef243334de35fe3ba70060e11aa373eeb193cc98cce37383d312002
"
checksum = "
f33000768c9ef47791df9a7da0b2bcd06f758c93dac13af1b9df25a84be0a204
"
dependencies = [
"anyhow",
"log",
"serde",
"serde_json",
"tokenizers",
"toktrie 0.6.2
8
",
"toktrie 0.6.2
9
",
]
[[package]]
...
...
@@ -5334,9 +5417,9 @@ dependencies = [
[[package]]
name = "toml_edit"
version = "0.22.2
3
"
version = "0.22.2
4
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee
"
checksum = "
17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474
"
dependencies = [
"indexmap 2.7.1",
"serde",
...
...
@@ -5365,7 +5448,7 @@ dependencies = [
"hyper-util",
"percent-encoding",
"pin-project",
"prost 0.13.
4
",
"prost 0.13.
5
",
"socket2",
"tokio",
"tokio-stream",
...
...
@@ -5529,6 +5612,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"async_zmq",
...
...
@@ -5568,12 +5652,12 @@ dependencies = [
"tokio",
"tokio-stream",
"tokio-util",
"toktrie 0.6.2
8
",
"toktrie_hf_tokenizers 0.6.2
8
",
"toktrie 0.6.2
9
",
"toktrie_hf_tokenizers 0.6.2
9
",
"tracing",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid 1.1
3.1
",
"uuid 1.1
4.0
",
"validator",
"xxhash-rust",
]
...
...
@@ -5617,7 +5701,7 @@ dependencies = [
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid 1.1
3.1
",
"uuid 1.1
4.0
",
"validator",
"xxhash-rust",
]
...
...
@@ -5647,9 +5731,9 @@ checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
[[package]]
name = "typenum"
version = "1.1
7
.0"
version = "1.1
8
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825
"
checksum = "
1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f
"
[[package]]
name = "ucd-trie"
...
...
@@ -5678,11 +5762,17 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.1
6
"
version = "1.0.1
7
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034
"
checksum = "
00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe
"
[[package]]
name = "unicode-normalization-alignments"
...
...
@@ -5789,9 +5879,9 @@ dependencies = [
[[package]]
name = "uuid"
version = "1.1
3.1
"
version = "1.1
4.0
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0
"
checksum = "
93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1
"
dependencies = [
"getrandom 0.3.1",
"serde",
...
...
@@ -6280,9 +6370,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.
1
"
version = "0.7.
3
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
86e376c75f4f43f44db463cf729e0d3acbf954d13e22c51e26e4c264b4ab545f
"
checksum = "
0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1
"
dependencies = [
"memchr",
]
...
...
lib/llm/Cargo.toml
View file @
86aff237
...
...
@@ -81,6 +81,7 @@ uuid = { workspace = true }
xxhash-rust
=
{
workspace
=
true
}
strum
=
{
workspace
=
true
}
async-openai
=
"0.27.2"
blake3
=
"1"
regex
=
"1"
...
...
lib/llm/src/engines/mistralrs.rs
View file @
86aff237
...
...
@@ -15,6 +15,7 @@
use
std
::{
cmp
::
min
,
num
::
NonZero
,
path
::
Path
,
sync
::
Arc
};
use
async_openai
::
types
::
FinishReason
;
use
async_stream
::
stream
;
use
async_trait
::
async_trait
;
use
either
::
Either
;
...
...
@@ -33,8 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use
triton_distributed_runtime
::
protocols
::
annotated
::
Annotated
;
use
crate
::
protocols
::
openai
::
chat_completions
::{
ChatCompletionChoiceDelta
,
ChatCompletionContent
,
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
Content
,
FinishReason
,
MessageRole
,
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
};
use
crate
::
types
::
openai
::
chat_completions
::
OpenAIChatCompletionsStreamingEngine
;
...
...
@@ -174,11 +174,14 @@ impl
let
(
tx
,
mut
rx
)
=
channel
(
10_000
);
let
maybe_tok
=
self
.pipeline
.lock
()
.await
.tokenizer
();
let
mut
prompt_tokens
=
0
;
let
mut
prompt_tokens
=
0
i32
;
let
mut
messages
=
vec!
[];
for
m
in
request
.messages
{
let
content
=
match
m
.content
{
Content
::
Text
(
prompt
)
=>
{
for
m
in
request
.inner.messages
{
let
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
inner_m
)
=
m
else
{
continue
;
};
let
content
=
match
inner_m
.content
{
async_openai
::
types
::
ChatCompletionRequestUserMessageContent
::
Text
(
prompt
)
=>
{
if
let
Some
(
tok
)
=
maybe_tok
.as_ref
()
{
prompt_tokens
=
tok
.encode
(
prompt
.clone
(),
false
)
...
...
@@ -187,12 +190,12 @@ impl
}
prompt
}
Content
::
ImageUrl
(
_
)
=>
{
anyhow
::
bail!
(
"
Content::ImageUrl
type is
not
supported"
);
_
=>
{
anyhow
::
bail!
(
"
Only Text
type is supported"
);
}
};
let
r
=
IndexMap
::
from
([
(
"role"
.to_string
(),
Either
::
Left
(
m
.role
.to_string
())),
(
"role"
.to_string
(),
Either
::
Left
(
"user"
.to_string
())),
(
"content"
.to_string
(),
Either
::
Left
(
content
)),
]);
messages
.push
(
r
);
...
...
@@ -204,7 +207,11 @@ impl
// level.
//tracing::info!(prompt_tokens, "Received prompt");
let
limit
=
DEFAULT_MAX_TOKENS
-
prompt_tokens
;
let
max_output_tokens
=
min
(
request
.max_tokens
.unwrap_or
(
limit
),
limit
);
#[allow(deprecated)]
let
max_output_tokens
=
min
(
request
.inner.max_tokens
.map
(|
x
|
x
as
i32
)
.unwrap_or
(
limit
),
limit
,
);
let
mistralrs_request
=
Request
::
Normal
(
NormalRequest
{
messages
:
RequestMessage
::
Chat
(
messages
),
...
...
@@ -247,35 +254,39 @@ impl
.unwrap_or
(
0
);
}
let
finish_reason
=
match
&
c
.choices
[
0
]
.finish_reason
{
Some
(
fr
)
=>
Some
(
fr
.parse
::
<
FinishReason
>
()
.unwrap_or
(
FinishReason
::
null
)),
Some
(
_
fr
)
=>
Some
(
FinishReason
::
Stop
),
//
Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::
Stop
)),
None
if
used_output_tokens
>=
max_output_tokens
=>
{
tracing
::
debug!
(
used_output_tokens
,
max_output_tokens
,
"Met or exceed max_tokens. Stopping."
);
Some
(
FinishReason
::
l
ength
)
Some
(
FinishReason
::
L
ength
)
}
None
=>
None
,
};
//tracing::trace!("from_assistant: {from_assistant}");
let
delta
=
ChatCompletionResponseDelta
{
#[allow(deprecated)]
let
inner
=
async_openai
::
types
::
CreateChatCompletionStreamResponse
{
id
:
c
.id
,
choices
:
vec!
[
ChatCompletionChoiceDelta
{
choices
:
vec!
[
async_openai
::
types
::
ChatChoiceStream
{
index
:
0
,
delta
:
ChatCompletionContent
{
delta
:
async_openai
::
types
::
ChatCompletionStreamResponseDelta
{
//role: c.choices[0].delta.role,
role
:
Some
(
Message
Role
::
a
ssistant
),
role
:
Some
(
async_openai
::
types
::
Role
::
A
ssistant
),
content
:
Some
(
from_assistant
),
tool_calls
:
None
,
refusal
:
None
,
function_call
:
None
,
},
logprobs
:
None
,
finish_reason
,
}],
model
:
c
.model
,
created
:
c
.created
as
u
64
,
created
:
c
.created
as
u
32
,
object
:
c
.object
.clone
(),
usage
:
None
,
system_fingerprint
:
Some
(
c
.system_fingerprint
),
service_tier
:
None
,
};
let
delta
=
ChatCompletionResponseDelta
{
inner
};
let
ann
=
Annotated
{
id
:
None
,
data
:
Some
(
delta
),
...
...
lib/llm/src/http/service/openai.rs
View file @
86aff237
...
...
@@ -220,17 +220,21 @@ async fn chat_completions(
let
request_id
=
uuid
::
Uuid
::
new_v4
()
.to_string
();
// todo - decide on default
let
streaming
=
request
.stream
.unwrap_or
(
false
);
let
streaming
=
request
.
inner.
stream
.unwrap_or
(
false
);
// update the request to always stream
let
request
=
ChatCompletionRequest
{
let
inner_
request
=
async_openai
::
types
::
Create
ChatCompletionRequest
{
stream
:
Some
(
true
),
..
request
..
request
.inner
};
let
request
=
ChatCompletionRequest
{
inner
:
inner_request
,
nvext
:
None
,
};
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let
model
=
&
request
.model
;
let
model
=
&
request
.
inner.
model
;
// todo - determine the proper error code for when a request model is not present
tracing
::
trace!
(
"Getting chat completions engine for model: {}"
,
model
);
...
...
lib/llm/src/preprocessor.rs
View file @
86aff237
...
...
@@ -275,7 +275,7 @@ impl
let
(
common_request
,
annotations
)
=
self
.preprocess_request
(
&
request
)
?
;
// update isl
response_generator
.update_isl
(
common_request
.token_ids
.len
()
as
i
32
);
response_generator
.update_isl
(
common_request
.token_ids
.len
()
as
u
32
);
// repack the common completion request
let
common_request
=
context
.map
(|
_
|
common_request
);
...
...
lib/llm/src/preprocessor/prompt/template/oai.rs
View file @
86aff237
...
...
@@ -18,35 +18,37 @@ use super::*;
use
minijinja
::{
context
,
value
::
Value
};
use
crate
::
protocols
::
openai
::{
chat_completions
::{
ChatCompletionMessage
,
ChatCompletionRequest
,
Content
,
MessageRole
},
completions
::
CompletionRequest
,
chat_completions
::
ChatCompletionRequest
,
completions
::
CompletionRequest
,
};
use
tracing
;
impl
OAIChatLikeRequest
for
ChatCompletionRequest
{
fn
messages
(
&
self
)
->
Value
{
Value
::
from_serialize
(
&
self
.messages
)
Value
::
from_serialize
(
&
self
.
inner.
messages
)
}
fn
tools
(
&
self
)
->
Option
<
Value
>
{
if
self
.tools
.is_none
()
{
if
self
.
inner.
tools
.is_none
()
{
None
}
else
{
Some
(
Value
::
from_serialize
(
&
self
.tools
))
Some
(
Value
::
from_serialize
(
&
self
.
inner.
tools
))
}
}
fn
tool_choice
(
&
self
)
->
Option
<
Value
>
{
if
self
.tool_choice
.is_none
()
{
if
self
.
inner.
tool_choice
.is_none
()
{
None
}
else
{
Some
(
Value
::
from_serialize
(
&
self
.tool_choice
))
Some
(
Value
::
from_serialize
(
&
self
.
inner.
tool_choice
))
}
}
fn
should_add_generation_prompt
(
&
self
)
->
bool
{
if
let
Some
(
last
)
=
self
.messages
.last
()
{
last
.role
==
MessageRole
::
user
if
let
Some
(
last
)
=
self
.inner.messages
.last
()
{
matches!
(
last
,
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
_
)
)
}
else
{
true
}
...
...
@@ -54,13 +56,22 @@ impl OAIChatLikeRequest for ChatCompletionRequest {
}
impl
OAIChatLikeRequest
for
CompletionRequest
{
fn
messages
(
&
self
)
->
Value
{
let
message
=
ChatCompletionMessage
{
role
:
MessageRole
::
user
,
content
:
Content
::
Text
(
self
.prompt
.clone
()),
name
:
None
,
};
Value
::
from_serialize
(
vec!
[
message
])
fn
messages
(
&
self
)
->
minijinja
::
value
::
Value
{
let
message
=
async_openai
::
types
::
ChatCompletionRequestMessage
::
User
(
async_openai
::
types
::
ChatCompletionRequestUserMessage
{
content
:
async_openai
::
types
::
ChatCompletionRequestUserMessageContent
::
Text
(
self
.prompt
.clone
(),
),
name
:
None
,
},
);
// Convert to a JSON string first
let
json_string
=
serde_json
::
to_string
(
&
vec!
[
message
])
.expect
(
"Serialization to JSON string failed"
);
// Convert to MiniJinja Value
minijinja
::
value
::
Value
::
from_safe_string
(
json_string
)
}
fn
should_add_generation_prompt
(
&
self
)
->
bool
{
...
...
lib/llm/src/protocols/common.rs
View file @
86aff237
...
...
@@ -66,6 +66,33 @@ pub enum FinishReason {
Cancelled
,
}
impl
std
::
fmt
::
Display
for
FinishReason
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
match
self
{
FinishReason
::
EoS
=>
write!
(
f
,
"eos"
),
FinishReason
::
Length
=>
write!
(
f
,
"length"
),
FinishReason
::
Stop
=>
write!
(
f
,
"stop"
),
FinishReason
::
Error
(
msg
)
=>
write!
(
f
,
"error: {}"
,
msg
),
FinishReason
::
Cancelled
=>
write!
(
f
,
"cancelled"
),
}
}
}
impl
std
::
str
::
FromStr
for
FinishReason
{
type
Err
=
anyhow
::
Error
;
fn
from_str
(
s
:
&
str
)
->
Result
<
Self
,
Self
::
Err
>
{
match
s
{
"eos"
=>
Ok
(
FinishReason
::
EoS
),
"length"
=>
Ok
(
FinishReason
::
Length
),
"stop"
=>
Ok
(
FinishReason
::
Stop
),
"cancelled"
=>
Ok
(
FinishReason
::
Cancelled
),
s
if
s
.starts_with
(
"error: "
)
=>
Ok
(
FinishReason
::
Error
(
s
[
7
..
]
.to_string
())),
_
=>
Err
(
anyhow
::
anyhow!
(
"Invalid FinishReason variant: '{}'"
,
s
)),
}
}
}
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will
...
...
lib/llm/src/protocols/openai.rs
View file @
86aff237
...
...
@@ -147,9 +147,9 @@ trait OpenAISamplingOptionsProvider {
}
trait
OpenAIStopConditionsProvider
{
fn
get_max_tokens
(
&
self
)
->
Option
<
i
32
>
;
fn
get_max_tokens
(
&
self
)
->
Option
<
u
32
>
;
fn
get_min_tokens
(
&
self
)
->
Option
<
i
32
>
;
fn
get_min_tokens
(
&
self
)
->
Option
<
u
32
>
;
fn
get_stop
(
&
self
)
->
Option
<
Vec
<
String
>>
;
...
...
@@ -200,7 +200,7 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
impl
<
T
:
OpenAIStopConditionsProvider
>
StopConditionsProvider
for
T
{
fn
extract_stop_conditions
(
&
self
)
->
Result
<
common
::
StopConditions
>
{
let
max_tokens
=
self
.get_max_tokens
()
.map
(|
x
|
x
as
u32
)
;
let
max_tokens
=
self
.get_max_tokens
();
let
min_tokens
=
self
.get_min_tokens
();
let
stop
=
self
.get_stop
();
...
...
@@ -218,7 +218,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
Ok
(
common
::
StopConditions
{
max_tokens
,
min_tokens
:
min_tokens
.map
(|
v
|
v
as
u32
)
,
min_tokens
,
stop
,
stop_token_ids_hidden
:
None
,
ignore_eos
,
...
...
@@ -321,6 +321,7 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
response
:
common
::
llm_backend
::
BackendOutput
,
)
->
Result
<
ResponseType
>
;
}
#[cfg(test)]
mod
tests
{
...
...
lib/llm/src/protocols/openai/chat_completions.rs
View file @
86aff237
...
...
@@ -13,775 +13,46 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
VecDeque
;
use
std
::
fmt
;
use
std
::
fmt
::
Display
;
use
derive_builder
::
Builder
;
use
serde
::
de
::{
self
,
SeqAccess
,
Visitor
};
use
serde
::
ser
::
SerializeMap
;
use
super
::
nvext
::
NvExt
;
use
super
::
nvext
::
NvExtProvider
;
use
super
::
OpenAISamplingOptionsProvider
;
use
super
::
OpenAIStopConditionsProvider
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserializer
,
Serializer
};
use
serde_json
::
Value
;
use
triton_distributed_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
use
validator
::
Validate
;
mod
aggregator
;
mod
delta
;
use
super
::
nvext
::
NvExtProvider
;
pub
use
super
::{
CompletionTokensDetails
,
CompletionUsage
,
PromptTokensDetails
};
//
pub use aggregator::DeltaAggregator;
pub
use
aggregator
::
DeltaAggregator
;
pub
use
delta
::
DeltaGenerator
;
use
super
::{
common
::{
self
,
ChatCompletionLogprobs
,
SamplingOptionsProvider
,
StopConditionsProvider
},
nvext
::
NvExt
,
validate_logit_bias
,
ContentProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
};
use
triton_distributed_runtime
::
protocols
::
annotated
::
AnnotationsProvider
;
/// Request object which is used to generate chat completions.
#[derive(Serialize,
Deserialize,
Builder,
Validate,
Debug,
Clone)]
#[builder(build_fn(private,
name
=
"build_internal"
,
validate
=
"Self::validate"
))]
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
ChatCompletionRequest
{
/// Multi-turn chat messages.
///
/// NIM Compatibility:
/// Multi-turn chat models vary, some of which work with the OpenAI ChatGPT format, while others
/// will require `NvExt`.
pub
messages
:
Vec
<
ChatCompletionMessage
>
,
/// Name of the model
#[builder(setter(into))]
pub
model
:
String
,
/// The maximum number of tokens that can be generated in the completion.
/// The token count of your prompt plus max_tokens cannot exceed the model's context length.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
#[validate(range(min
=
1
))]
pub
max_tokens
:
Option
<
i32
>
,
/// The minimum number of tokens to generate. We ignore stop tokens until we see this many
/// tokens. Leave this None unless you are working on the pre-processor.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
min_tokens
:
Option
<
i32
>
,
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
/// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
///
/// NIM Compatibility:
/// The NIM SDK can send extra meta data in the SSE stream using the `:` comment, `event:`,
/// or `id:` fields. See the `enable_sse_metadata` field in the NvExt object.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
pub
stream
:
Option
<
bool
>
,
/// How many chat completion choices to generate for each input message.
///
/// NIM Compatibility:
/// Values greater than 1 are not currently supported by NIM.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
n
:
Option
<
i32
>
,
/// What sampling `temperature` to use, between 0 and 2. Higher values like 0.8 will make the
/// output more random, while lower values like 0.2 will make it more focused and deterministic.
/// OpenAI defaults to 1.0; however, in this crate, the default is None, and model-specific defaults
/// can be applied later as part of associating the request with a given model.
///
/// OpenAI generally recommend altering this or `top_p` but not both.
///
/// TODO(): Add a model specific validation which could enforce only a single type of sampling can be used.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(range(min
=
"super::MIN_TEMPERATURE"
,
max
=
"super::MAX_TEMPERATURE"
))]
#[builder(default,
setter(into,
strip_option))]
pub
temperature
:
Option
<
f32
>
,
/// An alternative to sampling with `temperature`, called nucleus sampling, where the model
/// considers the results of the tokens with `top_p` probability mass. So 0.1 means only the tokens
/// comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or `temperature` but not both.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(range(min
=
"super::MIN_TOP_P"
,
max
=
"super::MAX_TOP_P"
))]
#[builder(default,
setter(into,
strip_option))]
pub
top_p
:
Option
<
f32
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
/// in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(range(
min
=
"super::MIN_FREQUENCY_PENALTY"
,
max
=
"super::MAX_FREQUENCY_PENALTY"
))]
#[builder(default,
setter(into,
strip_option))]
pub
frequency_penalty
:
Option
<
f32
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
/// the text so far, increasing the model's likelihood to talk about new topics.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(range(
min
=
"super::MIN_PRESENCE_PENALTY"
,
max
=
"super::MAX_PRESENCE_PENALTY"
))]
#[builder(default,
setter(into,
strip_option))]
pub
presence_penalty
:
Option
<
f32
>
,
/// OpenAI specific API fields:
/// See: <https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format>
///
/// NIM Compatibility:
/// This option is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default)]
pub
response_format
:
Option
<
Value
>
,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(length(max
=
4
))]
#[builder(default,
setter(into,
strip_option))]
pub
stop
:
Option
<
Vec
<
String
>>
,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
/// of each output token returned in the content of message.
///
/// Not all models support logprobs. If logprobs is set to true for a model that does not support it,
/// the request will be processed as if logprobs is set to false.
///
/// NIM Compatibility:
/// TODO - Add a NvExt `strict` object which will disable relaxing of model specific limitations; meaning,
/// if the user requests `logprobs` and the model does not support them, the request will fail wth an error.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
pub
logprobs
:
Option
<
bool
>
,
/// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
/// each with an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(range(min
=
0
,
max
=
20
))]
#[builder(default,
setter(into,
strip_option))]
pub
top_logprobs
:
Option
<
i32
>
,
/// Modify the likelihood of specified tokens appearing in the completion.
///
/// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an
/// associated bias value from -100 to 100. You can use this tokenizer tool to convert text to token IDs.
/// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact
/// effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of
/// selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
///
/// As specified in the OpenAI examples, this is a map of tokens_ids as strings to a bias value that
/// is an integer.
///
/// However, the OpenAI blog using the SDK shows that it can also be specified more accurately as a
/// map of token_ids as ints to a bias value that is also an int.
///
/// NIM Compatibility:
/// In the conversion of the OpenAI request to the internal NIM format, the keys of this map will be
/// validated to ensure they are integers. Since different models may have different tokenizers, the
/// range and values will again be validated on the compute backend to ensure they map to valid tokens
/// in the vocabulary of the model.
///
/// ```
/// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
///
/// let request = CompletionRequest::builder()
/// .prompt("What is the meaning of life?")
/// .model("meta/llama-3.1-8b-instruct")
/// .add_logit_bias(1337, -100) // using an int as a key is ok
/// .add_logit_bias("42", 100) // using a string as a key is also ok
/// .build()
/// .expect("Should not fail");
///
/// assert!(CompletionRequest::builder()
/// .prompt("What is the meaning of life?")
/// .model("meta/llama-3.1-8b-instruct")
/// .add_logit_bias("some non int", -100)
/// .build()
/// .is_err());
/// ```
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[validate(custom(function
=
"validate_logit_bias"
))]
#[builder(default,
setter(into,
strip_option))]
pub
logit_bias
:
Option
<
HashMap
<
String
,
i32
>>
,
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
///
/// NIM Compatibility:
/// If provided, then the value of this field will be included in the trace metadata and the accounting
/// data (if enabled).
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
user
:
Option
<
String
>
,
/// If specified, our system will make a best effort to sample deterministically, such that repeated
/// requests with the same seed and parameters should return the same result. Determinism is not guaranteed,
/// and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
seed
:
Option
<
i64
>
,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
/// provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.
///
/// NIM Compatibility:
/// This field is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default)]
pub
tools
:
Option
<
Vec
<
Tool
>>
,
/// Controls which (if any) function is called by the model. none means the model will not call a function
/// and instead generates a message. auto means the model can pick between generating a message or calling
/// a function. Specifying a particular function via {"type": "function", "function": {"name": "my_function"}}
/// forces the model to call that function.
///
/// `none` is the default when no functions are present. `auto` is the default if functions are present.
///
/// NIM Compatibility:
/// This field is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[serde(serialize_with
=
"serialize_tool_choice"
)]
#[builder(default)]
pub
tool_choice
:
Option
<
ToolChoiceType
>
,
/// Additional parameters supported by NIM backends
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
#[serde(flatten)]
pub
inner
:
async_openai
::
types
::
CreateChatCompletionRequest
,
pub
nvext
:
Option
<
NvExt
>
,
}
impl ChatCompletionRequest {
    /// Returns a new, empty builder for fluently constructing a request.
    pub fn builder() -> ChatCompletionRequestBuilder {
        ChatCompletionRequestBuilder::default()
    }
}
impl
ChatCompletionRequestBuilder
{
// This is a pre-build validate function
// This is called before the generated build method, in this case build_internal, is called
// This has access to the internal state of the builder
fn
validate
(
&
self
)
->
Result
<
(),
String
>
{
Ok
(())
}
/// Builds and validates the ChatCompletionRequest
///
/// ```rust
/// use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;
///
/// let request = ChatCompletionRequest::builder()
/// .model("mixtral-8x7b-instruct-v0.1")
/// .add_user_message("Hello")
/// .max_tokens(16)
/// .build()
/// .expect("Failed to build ChatCompletionRequest");
/// ```
pub
fn
build
(
&
self
)
->
anyhow
::
Result
<
ChatCompletionRequest
>
{
// Calls the build_private, validates the result, then performs addition
// post build validation where we are looking a mutually exclusive fields
// and ensuring that there are not mutually exclusive collisions.
let
request
=
self
.build_internal
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to build ChatCompletionRequest: {}"
,
e
))
?
;
request
.validate
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to validate ChatCompletionRequest: {}"
,
e
))
?
;
// check mutually exclusive fields
if
request
.top_logprobs
.is_some
()
{
if
request
.logprobs
.is_none
()
{
anyhow
::
bail!
(
"top_logprobs requires logprobs to be set to true"
);
}
if
let
Some
(
logprobs
)
=
request
.logprobs
{
if
!
logprobs
{
anyhow
::
bail!
(
"top_logprobs requires logprobs to be set to true"
);
}
}
}
Ok
(
request
)
}
/// Add a message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
/// This will either create or append to the `Vec<ChatCompletionMessage>`
pub
fn
add_message
(
&
mut
self
,
message
:
ChatCompletionMessage
)
->
&
mut
Self
{
// If messages exist we get them or we create new messages with Vec::new
self
.messages
.get_or_insert_with
(
Vec
::
new
)
.push
(
message
);
self
}
/// Add a user message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
pub
fn
add_user_message
(
&
mut
self
,
content
:
impl
Into
<
String
>
)
->
&
mut
Self
{
self
.add_message
(
ChatCompletionMessage
{
role
:
MessageRole
::
user
,
content
:
Content
::
Text
(
content
.into
()),
name
:
None
,
})
}
/// Add an assistant message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
pub
fn
add_assistant_message
(
&
mut
self
,
content
:
impl
Into
<
String
>
)
->
&
mut
Self
{
self
.add_message
(
ChatCompletionMessage
{
role
:
MessageRole
::
assistant
,
content
:
Content
::
Text
(
content
.into
()),
name
:
None
,
})
}
/// Add a system message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
pub
fn
add_system_message
(
&
mut
self
,
content
:
impl
Into
<
String
>
)
->
&
mut
Self
{
self
.add_message
(
ChatCompletionMessage
{
role
:
MessageRole
::
system
,
content
:
Content
::
Text
(
content
.into
()),
name
:
None
,
})
}
/// Add a stop condition to the `Vec<String>` in the ChatCompletionRequest
/// This will either create or append to the `Vec<String>`
pub
fn
add_stop
(
&
mut
self
,
stop
:
impl
Into
<
String
>
)
->
&
mut
Self
{
self
.stop
.get_or_insert_with
(||
Some
(
vec!
[]))
.as_mut
()
.expect
(
"stop should always be Some(Vec)"
)
.push
(
stop
.into
());
self
}
/// Add a token and bias to the `HashMap<String, i32>` in the ChatCompletionRequest
/// This will either create or update the `HashMap<String, i32>`
/// See: [`ChatCompletionRequest::logit_bias`] for more details
pub
fn
add_logit_bias
<
T
>
(
&
mut
self
,
token_id
:
T
,
bias
:
i32
)
->
&
mut
Self
where
T
:
std
::
fmt
::
Display
,
{
self
.logit_bias
.get_or_insert_with
(||
Some
(
HashMap
::
new
()))
.as_mut
()
.expect
(
"logit_bias should always be Some(HashMap)"
)
.insert
(
token_id
.to_string
(),
bias
);
self
}
}
/// Each turn in a conversation is represented by a ChatCompletionMessage.
#[derive(Builder, Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
    /// Author of this message.
    pub role: MessageRole,

    /// Message payload; deserialization accepts either a bare string or an
    /// array of content parts (see `deserialize_content`).
    #[serde(deserialize_with = "deserialize_content")]
    pub content: Content,

    /// Optional participant name attached to the message.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    #[builder(default)]
    pub name: Option<String>,
}
/// The author role attached to a single chat message.
///
/// Lower-case variant names deliberately mirror the wire format used by the
/// OpenAI chat API ("user", "system", "assistant", "function").
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum MessageRole {
    user,
    system,
    assistant,
    function,
}

impl Display for MessageRole {
    /// Writes the role's lower-case wire name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let name = match self {
            MessageRole::user => "user",
            MessageRole::system => "system",
            MessageRole::assistant => "assistant",
            MessageRole::function => "function",
        };
        write!(f, "{name}")
    }
}
/// Message payload: either plain text or a list of multimodal content parts.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
pub enum Content {
    Text(String),
    ImageUrl(Vec<ImageUrl>),
}

impl serde::Serialize for Content {
    /// Serializes `Text` as a bare JSON string and `ImageUrl` as an array of
    /// content parts, matching the OpenAI request schema.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Content::Text(text) => serializer.serialize_str(text),
            Content::ImageUrl(parts) => parts.serialize(serializer),
        }
    }
}
fn
deserialize_content
<
'de
,
D
>
(
deserializer
:
D
)
->
Result
<
Content
,
D
::
Error
>
where
D
:
Deserializer
<
'de
>
,
{
struct
ContentVisitor
;
impl
<
'de
>
Visitor
<
'de
>
for
ContentVisitor
{
type
Value
=
Content
;
fn
expecting
(
&
self
,
formatter
:
&
mut
fmt
::
Formatter
)
->
fmt
::
Result
{
formatter
.write_str
(
"a string or an array of content parts"
)
}
fn
visit_str
<
E
>
(
self
,
value
:
&
str
)
->
Result
<
Self
::
Value
,
E
>
where
E
:
de
::
Error
,
{
Ok
(
Content
::
Text
(
value
.to_owned
()))
}
fn
visit_seq
<
A
>
(
self
,
mut
seq
:
A
)
->
Result
<
Self
::
Value
,
A
::
Error
>
where
A
:
SeqAccess
<
'de
>
,
{
let
mut
parts
=
Vec
::
new
();
while
let
Some
(
value
)
=
seq
.next_element
::
<
String
>
()
?
{
if
value
.starts_with
(
"http://"
)
||
value
.starts_with
(
"https://"
)
{
parts
.push
(
ImageUrl
{
r
#
type
:
ContentType
::
image_url
,
text
:
None
,
image_url
:
Some
(
ImageUrlType
{
url
:
value
}),
});
}
else
{
parts
.push
(
ImageUrl
{
r
#
type
:
ContentType
::
text
,
text
:
Some
(
value
),
image_url
:
None
,
});
}
}
Ok
(
Content
::
ImageUrl
(
parts
))
}
}
deserializer
.deserialize_any
(
ContentVisitor
)
}
/// Discriminator tag for a multimodal content part.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
    text,
    image_url,
}

/// Wrapper carrying the URL of an image content part.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub struct ImageUrlType {
    pub url: String,
}
#[derive(Debug,
Deserialize,
Serialize,
Clone,
PartialEq,
Eq)]
#[allow(non_camel_case_types)]
pub
struct
ImageUrl
{
pub
r
#
type
:
ContentType
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
text
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
image_url
:
Option
<
ImageUrlType
>
,
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
ChatCompletionResponse
{
#[serde(flatten)]
pub
inner
:
async_openai
::
types
::
CreateChatCompletionResponse
,
}
/// Represents a chat completion response returned by model, based on the provided input.
pub
type
ChatCompletionResponse
=
ChatCompletionGeneric
<
ChatCompletionChoice
>
;
/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input.
pub
type
ChatCompletionResponseDelta
=
ChatCompletionGeneric
<
ChatCompletionChoiceDelta
>
;
/// Common structure for chat completion responses; the only delta is the type
/// of choices, which differs between streaming and non-streaming requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionGeneric<C>
where
    C: Serialize + Clone + ContentProvider,
{
    /// A unique identifier for the chat completion.
    pub id: String,

    /// A list of chat completion choices. Can be more than one if n is greater than 1.
    pub choices: Vec<C>,

    /// The Unix timestamp (in seconds) of when the chat completion was created.
    pub created: u64,

    /// The model used for the chat completion.
    pub model: String,

    /// The object type: `chat.completion` when `C` is `ChatCompletionChoice`,
    /// `chat.completion.chunk` when `C` is `ChatCompletionChoiceDelta`.
    pub object: String,

    /// Usage information for the completion request.
    pub usage: Option<CompletionUsage>,

    /// The service tier used for processing the request, optional.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<ServiceTier>,

    /// This fingerprint represents the backend configuration that the model runs with.
    ///
    /// Can be used in conjunction with the seed request parameter to understand when backend changes
    /// have been made that might impact determinism.
    ///
    /// NIM Compatibility:
    /// This field is not supported by the NIM; however it will be added in the future.
    /// The optional nature of this field will be relaxed when it is supported.
    pub system_fingerprint: Option<String>,
    // TODO() - add NvResponseExtention
}
/// Service tier used to process the request; serialized in snake_case
/// ("auto", "scale", "default").
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Scale,
    Default,
}
/// One generated alternative in a non-streaming chat completion response.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ChatCompletionChoice {
    /// A chat completion message generated by the model.
    pub message: ChatCompletionContent,

    /// The index of the choice in the list of choices.
    pub index: u64,

    /// The reason the model stopped generating tokens: `stop` for a natural
    /// stop point or a provided stop sequence, `length` when the requested
    /// maximum token count was reached, `content_filter` when content was
    /// omitted by a content filter, `tool_calls` when the model called a tool.
    ///
    /// NIM Compatibility:
    /// Only `stop` and `length` are currently supported by NIM.
    /// NIM may also provide additional reasons in the future, such as `error`, `timeout` or `cancelation`.
    pub finish_reason: FinishReason,

    /// Log probability information for the choice, optional field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatCompletionLogprobs>,
}

impl ContentProvider for ChatCompletionChoice {
    /// Delegates to the message's text content.
    fn content(&self) -> String {
        self.message.content()
    }
}
/// Same as ChatCompletionMessage, but received during a response stream.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatCompletionChoiceDelta {
    /// The index of the choice in the list of choices.
    pub index: u64,

    /// The reason the model stopped generating tokens, or `None` while the
    /// stream is still in progress. Same vocabulary as the non-streaming
    /// [`ChatCompletionChoice::finish_reason`].
    ///
    /// NIM Compatibility:
    /// Only `stop` and `length` are currently supported by NIM.
    /// NIM may also provide additional reasons in the future, such as `error`, `timeout` or `cancelation`.
    pub finish_reason: Option<FinishReason>,

    /// A chat completion delta generated by streamed model responses.
    pub delta: ChatCompletionContent,

    /// Log probability information for the choice, optional field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatCompletionLogprobs>,
}

impl ContentProvider for ChatCompletionChoiceDelta {
    /// Delegates to the delta's text content.
    fn content(&self) -> String {
        self.delta.content()
    }
}
/// A chat completion message generated by the model.
#[derive(Clone,
Debug,
Deserialize,
Serialize)]
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
ChatCompletionContent
{
/// The role of the author of this message.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
role
:
Option
<
MessageRole
>
,
/// The contents of the message.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
content
:
Option
<
String
>
,
/// Tool calls made by the model.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
tool_calls
:
Option
<
Vec
<
ToolCall
>>
,
}
impl
ContentProvider
for
ChatCompletionContent
{
fn
content
(
&
self
)
->
String
{
self
.content
.clone
()
.unwrap_or
(
""
.to_string
())
}
#[serde(flatten)]
pub
inner
:
async_openai
::
types
::
ChatCompletionStreamResponseDelta
,
}
#[derive(Debug,
Serialize,
Deserialize,
Clone,
PartialEq,
Eq)]
pub
enum
ToolChoiceType
{
None
,
Auto
,
ToolChoice
{
tool
:
Tool
},
#[derive(Serialize,
Deserialize,
Validate,
Debug,
Clone)]
pub
struct
ChatCompletionResponseDelta
{
#[serde(flatten)]
pub
inner
:
async_openai
::
types
::
CreateChatCompletionStreamResponse
,
}
/// Schema describing a callable function exposed to the model as a tool.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: FunctionParameters,
}

/// JSON Schema primitive type tags, serialized in lowercase.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JSONSchemaType {
    Object,
    Number,
    String,
    Array,
    Null,
    Boolean,
}

/// A (possibly nested) JSON Schema property definition.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
pub struct JSONSchemaDefine {
    /// Schema type; serialized as `type`.
    #[serde(rename = "type")]
    pub schema_type: Option<JSONSchemaType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enum_values: Option<Vec<String>>,
    /// Child properties for object schemas; boxed to keep the type finite.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
    /// Element schema for array schemas.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JSONSchemaDefine>>,
}

/// Top-level parameter schema attached to a [`Function`].
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct FunctionParameters {
    /// Schema type; serialized as `type`.
    #[serde(rename = "type")]
    pub schema_type: JSONSchemaType,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
/// Reason a generation terminated, using the OpenAI wire names plus the
/// NIM-specific `cancelled` and the placeholder `null`.
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum FinishReason {
    stop,
    length,
    content_filter,
    tool_calls,
    cancelled,
    null,
}

impl std::str::FromStr for FinishReason {
    type Err = String;

    /// Parses a finish reason from its wire name.
    ///
    /// Accepts every name the `Display` impl produces, so
    /// `reason.to_string().parse::<FinishReason>()` round-trips.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "stop" => Ok(FinishReason::stop),
            "length" => Ok(FinishReason::length),
            "content_filter" => Ok(FinishReason::content_filter),
            "tool_calls" => Ok(FinishReason::tool_calls),
            // Fix: "cancelled" was previously rejected even though Display
            // emits it, breaking Display/FromStr round-tripping.
            "cancelled" => Ok(FinishReason::cancelled),
            "null" => Ok(FinishReason::null),
            _ => Err(format!("Unknown FinishReason: {}", s)),
        }
    }
}

impl std::fmt::Display for FinishReason {
    /// Writes the variant's wire name (identical to the Rust identifier).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match self {
            FinishReason::stop => "stop",
            FinishReason::length => "length",
            FinishReason::content_filter => "content_filter",
            FinishReason::tool_calls => "tool_calls",
            FinishReason::cancelled => "cancelled",
            FinishReason::null => "null",
        };
        write!(f, "{name}")
    }
}
#[derive(Debug,
Deserialize,
Serialize)]
#[allow(non_camel_case_types)]
pub
struct
FinishDetails
{
pub
r
#
type
:
FinishReason
,
pub
stop
:
String
,
}
#[derive(Debug,
Deserialize,
Serialize,
Clone)]
pub
struct
ToolCall
{
pub
id
:
String
,
pub
r
#
type
:
String
,
pub
function
:
ToolCallFunction
,
}
#[derive(Debug,
Deserialize,
Serialize,
Clone)]
pub
struct
ToolCallFunction
{
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
name
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
arguments
:
Option
<
String
>
,
}
/// Custom serializer for the `tool_choice` field.
///
/// `None`/`Auto` variants serialize to the strings "none"/"auto"; a specific
/// `ToolChoice` serializes as a two-entry map `{"type": ..., "function": ...}`;
/// an absent value serializes as null.
fn serialize_tool_choice<S>(
    value: &Option<ToolChoiceType>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    match value {
        Some(ToolChoiceType::None) => serializer.serialize_str("none"),
        Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
        Some(ToolChoiceType::ToolChoice { tool }) => {
            // Entry order matters for byte-stable output: type, then function.
            let mut map = serializer.serialize_map(Some(2))?;
            map.serialize_entry("type", &tool.r#type)?;
            map.serialize_entry("function", &tool.function)?;
            map.end()
        }
        None => serializer.serialize_none(),
    }
}
#[derive(Debug,
Deserialize,
Serialize,
Clone,
PartialEq,
Eq)]
pub
struct
Tool
{
pub
r
#
type
:
ToolType
,
pub
function
:
Function
,
}
#[derive(Debug,
Deserialize,
Serialize,
Copy,
Clone,
PartialEq,
Eq)]
#[serde(rename_all
=
"snake_case"
)]
pub
enum
ToolType
{
Function
,
}
impl
ChatCompletionRequest
{}
impl
NvExtProvider
for
ChatCompletionRequest
{
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
self
.nvext
.as_ref
()
...
...
@@ -810,19 +81,19 @@ impl AnnotationsProvider for ChatCompletionRequest {
impl
OpenAISamplingOptionsProvider
for
ChatCompletionRequest
{
fn
get_temperature
(
&
self
)
->
Option
<
f32
>
{
self
.temperature
self
.
inner.
temperature
}
fn
get_top_p
(
&
self
)
->
Option
<
f32
>
{
self
.top_p
self
.
inner.
top_p
}
fn
get_frequency_penalty
(
&
self
)
->
Option
<
f32
>
{
self
.frequency_penalty
self
.
inner.
frequency_penalty
}
fn
get_presence_penalty
(
&
self
)
->
Option
<
f32
>
{
self
.presence_penalty
self
.
inner.
presence_penalty
}
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
...
...
@@ -830,815 +101,26 @@ impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
}
}
#[allow(deprecated)]
impl
OpenAIStopConditionsProvider
for
ChatCompletionRequest
{
fn
get_max_tokens
(
&
self
)
->
Option
<
i32
>
{
self
.max_tokens
fn
get_max_tokens
(
&
self
)
->
Option
<
u32
>
{
// ALLOW: max_tokens is deprecated in favor of max_completion_tokens
self
.inner.max_tokens
}
fn
get_min_tokens
(
&
self
)
->
Option
<
i32
>
{
self
.min_tokens
fn
get_min_tokens
(
&
self
)
->
Option
<
u32
>
{
// TODO THIS IS WRONG min_tokens does not exist
None
}
fn
get_stop
(
&
self
)
->
Option
<
Vec
<
String
>>
{
self
.stop
.clone
()
// TODO THIS IS WRONG should instead do
// Vec<String> -> async_openai::types::Stop
// self.inner.stop.clone()
None
}
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
self
.nvext
.as_ref
()
}
}
/// Implements TryFrom for converting an OpenAI's ChatCompletionRequest to an Engine's CompletionRequest
impl
TryFrom
<
ChatCompletionRequest
>
for
common
::
CompletionRequest
{
type
Error
=
anyhow
::
Error
;
fn
try_from
(
request
:
ChatCompletionRequest
)
->
Result
<
Self
,
Self
::
Error
>
{
// openai_api_rs::v1::chat_completion
// pub struct ChatCompletionRequest {
// NA pub model: String,
// L pub messages: Vec<ChatCompletionMessage, Global>,
// SO pub temperature: Option<f32>,
// SO pub top_p: Option<f32>,
// SO pub n: Option<i32>,
// ** pub response_format: Option<Value>,
// NA pub stream: Option<bool>, // See Issue #8
// SC pub stop: Option<Vec<String, Global>>,
// SC pub max_tokens: Option<i32>,
// SO pub presence_penalty: Option<f32>,
// SO pub frequency_penalty: Option<f32>,
// ** pub logit_bias: Option<HashMap<String, i32, RandomState>>,
// ** pub user: Option<String>,
// SO pub seed: Option<i64>,
// ** pub tools: Option<Vec<Tool, Global>>,
// ** pub tool_choice: Option<ToolChoiceType>,
// }
//
// ** not supported
// NA not applicable
// L local in this method
// SO extract_sampling_options
// SC extract_stop_conditions
// first we validate the OpenAI request
// we can not validate everything as some fields require backend awareness
// however, we can validate against the public OpenAI limit
request
.validate
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to validate ChatCompletionRequest: {}"
,
e
))
?
;
// todo(ryan) - open a ticket to support this
if
request
.logit_bias
.is_some
()
{
anyhow
::
bail!
(
"logit_bias is not supported"
);
}
// todo(ryan) - add support for user
if
request
.user
.is_some
()
{
anyhow
::
bail!
(
"user is not supported"
);
}
if
request
.response_format
.is_some
()
{
anyhow
::
bail!
(
"response_format is not supported"
);
}
if
request
.tools
.is_some
()
{
anyhow
::
bail!
(
"tools is not supported"
);
}
if
request
.tool_choice
.is_some
()
{
anyhow
::
bail!
(
"tool_choice is not supported"
);
}
// sampling options
let
sampling_options
=
request
.extract_sampling_options
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to extract SamplingOptions: {}"
,
e
))
?
;
// stop conditions
let
stop_conditions
=
request
.extract_stop_conditions
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to extract StopConditions: {}"
,
e
))
?
;
// first we need to process the messages
let
prompt
=
common
::
PromptType
::
ChatCompletion
(
validate_and_collect_chat_messages
(
request
.messages
)
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to validate chat messages: {}"
,
e
))
?
,
);
// return the completion request
Ok
(
common
::
CompletionRequest
{
prompt
,
stop_conditions
,
sampling_options
,
mdc_sum
:
None
,
annotations
:
None
,
})
}
}
impl
TryFrom
<
common
::
StreamingCompletionResponse
>
for
ChatCompletionChoice
{
type
Error
=
anyhow
::
Error
;
fn
try_from
(
response
:
common
::
StreamingCompletionResponse
)
->
Result
<
Self
,
Self
::
Error
>
{
let
choice
=
ChatCompletionChoice
{
index
:
response
.delta.index
.unwrap_or
(
0
)
as
u64
,
message
:
ChatCompletionContent
{
role
:
Some
(
MessageRole
::
assistant
),
content
:
response
.delta.text
,
tool_calls
:
None
,
},
finish_reason
:
match
&
response
.delta.finish_reason
{
Some
(
common
::
FinishReason
::
EoS
)
=>
FinishReason
::
stop
,
Some
(
common
::
FinishReason
::
Stop
)
=>
FinishReason
::
stop
,
Some
(
common
::
FinishReason
::
Length
)
=>
FinishReason
::
length
,
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
"finish_reason::error = {}"
,
err_msg
));
}
Some
(
common
::
FinishReason
::
Cancelled
)
=>
FinishReason
::
null
,
None
=>
FinishReason
::
null
,
},
logprobs
:
response
.logprobs
,
};
Ok
(
choice
)
}
}
impl
TryFrom
<
common
::
StreamingCompletionResponse
>
for
ChatCompletionChoiceDelta
{
type
Error
=
anyhow
::
Error
;
fn
try_from
(
response
:
common
::
StreamingCompletionResponse
)
->
Result
<
Self
,
Self
::
Error
>
{
let
choice
=
ChatCompletionChoiceDelta
{
index
:
response
.delta.index
.unwrap_or
(
0
)
as
u64
,
delta
:
ChatCompletionContent
{
role
:
Some
(
MessageRole
::
assistant
),
content
:
response
.delta.text
,
tool_calls
:
None
,
},
finish_reason
:
match
&
response
.delta.finish_reason
{
Some
(
common
::
FinishReason
::
EoS
)
=>
Some
(
FinishReason
::
stop
),
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
FinishReason
::
stop
),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
FinishReason
::
length
),
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
"finish_reason::error = {}"
,
err_msg
));
}
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
FinishReason
::
null
),
None
=>
None
,
},
logprobs
:
response
.logprobs
,
};
Ok
(
choice
)
}
}
fn
validate_and_collect_chat_messages
(
messages
:
Vec
<
ChatCompletionMessage
>
,
)
->
Result
<
common
::
ChatContext
,
anyhow
::
Error
>
{
let
mut
system_prompt
=
None
;
let
mut
turns
=
VecDeque
::
new
();
let
mut
last_role
=
MessageRole
::
assistant
;
for
message
in
messages
{
match
message
.role
{
MessageRole
::
system
=>
{
if
system_prompt
.is_some
()
{
return
Err
(
anyhow
::
anyhow!
(
"More than one system message found"
));
}
system_prompt
=
Some
(
message
.content
);
}
MessageRole
::
user
|
MessageRole
::
assistant
=>
{
if
last_role
==
message
.role
{
if
turns
.is_empty
()
{
return
Err
(
anyhow
::
anyhow!
(
"First message must be a user message"
));
}
return
Err
(
anyhow
::
anyhow!
(
"User and assistant messages must alternate"
));
}
last_role
=
message
.role
.clone
();
turns
.push_back
(
message
);
}
MessageRole
::
function
=>
{}
// Ignoring function messages as per assumption.
}
}
if
let
Some
(
first
)
=
turns
.front
()
{
if
let
MessageRole
::
assistant
=
first
.role
{
return
Err
(
anyhow
::
anyhow!
(
"Sequence must start with a user message"
));
}
}
if
turns
.len
()
%
2
==
0
{
return
Err
(
anyhow
::
anyhow!
(
"Sequence must end with a user message"
));
}
let
mut
context
=
Vec
::
new
();
while
turns
.len
()
>=
2
{
let
user
=
turns
.pop_front
()
.unwrap
();
let
asst
=
turns
.pop_front
()
.unwrap
();
let
user
=
match
user
.content
{
Content
::
Text
(
text
)
=>
text
,
_
=>
return
Err
(
anyhow
::
anyhow!
(
"User message must be text"
)),
};
let
asst
=
match
asst
.content
{
Content
::
Text
(
text
)
=>
text
,
_
=>
return
Err
(
anyhow
::
anyhow!
(
"Assistant message must be text"
)),
};
context
.push
(
common
::
ChatTurn
{
user
,
assistant
:
asst
,
});
}
let
prompt
=
turns
.pop_back
()
.unwrap
();
let
prompt
=
match
prompt
.content
{
Content
::
Text
(
text
)
=>
text
,
_
=>
return
Err
(
anyhow
::
anyhow!
(
"Prompt message must be text"
)),
};
let
system_prompt
=
match
system_prompt
{
Some
(
Content
::
Text
(
text
))
=>
Some
(
text
),
Some
(
_
)
=>
return
Err
(
anyhow
::
anyhow!
(
"System prompt must be text"
)),
None
=>
None
,
};
Ok
(
common
::
ChatContext
{
completion
:
common
::
CompletionContext
{
prompt
,
system_prompt
,
},
context
,
})
}
#[cfg(test)]
mod
tests
{
use
anyhow
::
Result
;
use
serde_json
::
json
;
use
std
::
error
::
Error
;
use
super
::
*
;
#[test]
fn
test_chat_completions_valid_request_minimal
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.build
();
assert
!
(
request
.is_ok
(),
"Request should succeed with minimal fields"
);
Ok
(())
}
#[test]
fn
test_chat_completions_valid_request_full
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.max_tokens
(
50
)
.stream
(
true
)
.n
(
1
)
.temperature
(
1.0
)
.top_p
(
0.9
)
.frequency_penalty
(
0.5
)
.presence_penalty
(
0.5
)
.stop
(
vec!
[
"The end."
.to_string
()])
.logprobs
(
true
)
.top_logprobs
(
5
)
.logit_bias
(
HashMap
::
new
())
.user
(
"test_user"
)
.seed
(
1234
)
.build
();
println!
(
"{:?}"
,
request
);
assert
!
(
request
.is_ok
(),
"Request should succeed with all fields set"
);
Ok
(())
}
#[test]
fn
test_chat_completions_top_logprobs_requires_logprobs
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.top_logprobs
(
5
)
// logprobs is not set to true
.build
();
assert
!
(
request
.is_err
(),
"Request should fail when top_logprobs is set without logprobs being true"
);
Ok
(())
}
#[ignore]
#[test]
fn
test_chat_completions_max_tokens_out_of_range
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.max_tokens
(
4097
)
// assuming the model has a max context length of 4096
.build
();
assert
!
(
request
.is_err
(),
"Request should fail when max_tokens exceeds model's context length"
);
Ok
(())
}
#[test]
fn
test_chat_completions_invalid_top_p
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.top_p
(
1.5
)
// Invalid, should be between 0 and 1
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with invalid top_p value"
);
Ok
(())
}
#[test]
fn
test_chat_completions_missing_messages
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
// Missing messages field in the request
let
request_result
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
// Valid model
.build
();
// This should fail because no messages are provided.
assert
!
(
request_result
.is_err
(),
"Expected request to fail without messages."
);
if
let
Err
(
e
)
=
request_result
{
println!
(
"Expected error: {}"
,
e
);
// Optionally print the error for debugging
}
Ok
(())
}
#[test]
fn
test_chat_completions_negative_max_tokens
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello, world!"
)
.max_tokens
(
-
10
)
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with negative max_tokens"
);
Ok
(())
}
#[ignore]
#[test]
fn
test_chat_completions_unsupported_logit_bias
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello, world!"
)
.add_logit_bias
(
"50256"
,
-
100
)
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with logit_bias"
);
Ok
(())
}
#[test]
fn
test_chat_completions_invalid_temperature
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hello!"
)
.temperature
(
2.5
)
// Invalid, should be between 0 and 2
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with invalid temperature"
);
Ok
(())
}
#[test]
fn
test_chat_completions_max_stop_sequences
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Tell me a story."
)
.stop
(
vec!
[
"The end."
.to_string
(),
"Once upon a time,"
.to_string
(),
"And then,"
.to_string
(),
"They lived happily ever after."
.to_string
(),
])
// 4 stop sequences, valid
.build
();
assert
!
(
request
.is_ok
(),
"Request should succeed with 4 stop sequences"
);
Ok
(())
}
#[test]
fn
test_chat_completions_large_stop_sequences
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Tell me a story."
)
.stop
(
vec!
[
"The end."
.to_string
(),
"And so,"
.to_string
(),
"Once upon a time,"
.to_string
(),
"They lived happily ever after."
.to_string
(),
"Unexpected stop."
.to_string
(),
])
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with too many stop sequences"
);
Ok
(())
}
#[ignore]
#[test]
fn
test_chat_completions_invalid_stop_sequences
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Tell me a joke."
)
.stop
(
vec!
[
""
.to_string
()])
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with invalid stop sequences"
);
Ok
(())
}
#[ignore]
#[test]
fn
test_chat_completions_presence_penalty_out_of_range
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"What's up?"
)
.presence_penalty
(
3.0
)
// Out of valid range (-2.0 to 2.0)
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with invalid presence_penalty"
);
Ok
(())
}
#[test]
fn
test_chat_completions_invalid_presence_penalty
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"What's up?"
)
.presence_penalty
(
-
2.5
)
// Invalid, should be between -2.0 and 2.0
.build
();
assert
!
(
request
.is_err
(),
"Request should fail with invalid presence_penalty"
);
Ok
(())
}
#[ignore]
#[tokio::test]
async
fn
test_chat_completions_with_user_field
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Hi there!"
)
.user
(
"test_user"
)
.build
()
.unwrap
();
// assert!(request.is_err(), "Request should fail with 'user' field");
let
result
:
Result
<
common
::
CompletionRequest
>
=
request
.try_into
();
assert
!
(
result
.is_err
(),
"Conversion should fail with 'user' field set"
,
);
Ok
(())
}
#[test]
fn
test_chat_completions_valid_with_seed
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"meta/llama-3.1-8b-instruct"
)
.add_user_message
(
"Repeatable result"
)
.seed
(
12345
)
.build
();
assert
!
(
request
.is_ok
(),
"Request should succeed with seed value for determinism"
);
Ok
(())
}
#[test]
fn
test_validate_chat_messages_multiple_system_messages
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"test-model"
)
.add_system_message
(
"System message 1"
)
.add_system_message
(
"System message 2"
)
.add_user_message
(
"Hello!"
)
.build
()
?
;
let
result
=
validate_and_collect_chat_messages
(
request
.messages
.clone
());
assert
!
(
result
.is_err
());
if
let
Err
(
e
)
=
result
{
assert_eq!
(
e
.to_string
(),
"More than one system message found"
);
}
Ok
(())
}
#[test]
fn
test_validate_chat_messages_user_messages_do_not_alternate
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"test-model"
)
.add_user_message
(
"Hello!"
)
.add_user_message
(
"How are you?"
)
.build
()
?
;
let
result
=
validate_and_collect_chat_messages
(
request
.messages
.clone
());
assert
!
(
result
.is_err
());
if
let
Err
(
e
)
=
result
{
assert_eq!
(
e
.to_string
(),
"User and assistant messages must alternate"
);
}
Ok
(())
}
#[ignore]
#[test]
fn
test_validate_chat_messages_user_message_not_text
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
message
=
ChatCompletionMessage
{
role
:
MessageRole
::
user
,
content
:
Content
::
ImageUrl
(
vec!
[
ImageUrl
{
r
#
type
:
ContentType
::
image_url
,
text
:
None
,
image_url
:
Some
(
ImageUrlType
{
url
:
"http://example.com/image.png"
.to_string
(),
}),
}]),
name
:
None
,
};
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"test-model"
)
.add_message
(
message
)
.build
()
?
;
let
result
=
validate_and_collect_chat_messages
(
request
.messages
.clone
());
assert
!
(
result
.is_err
());
if
let
Err
(
e
)
=
result
{
assert_eq!
(
e
.to_string
(),
"Generic error: User message must be text"
);
}
Ok
(())
}
#[test]
fn
test_try_from_chat_completion_request_with_unsupported_fields
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"test-model"
)
.add_user_message
(
"Hello!"
)
.response_format
(
Some
(
json!
({
"format"
:
"unsupported"
})))
.tools
(
Some
(
vec!
[
Tool
{
r
#
type
:
ToolType
::
Function
,
function
:
Function
{
name
:
"test_function"
.to_string
(),
description
:
None
,
parameters
:
FunctionParameters
{
schema_type
:
JSONSchemaType
::
Object
,
properties
:
None
,
required
:
None
,
},
},
}]))
.tool_choice
(
Some
(
ToolChoiceType
::
Auto
))
.build
()
?
;
let
result
:
Result
<
common
::
CompletionRequest
>
=
request
.try_into
();
assert
!
(
result
.is_err
(),
"Conversion should fail with unsupported fields"
);
Ok
(())
}
#[test]
fn
test_deserialize_content_with_image_urls
()
{
let
json_data
=
r#"
{
"role": "assistant",
"content": [
"This is a text message.",
"https://example.com/image1.png",
"Another text message.",
"https://example.com/image2.png"
]
}
"#
;
let
message
:
ChatCompletionMessage
=
serde_json
::
from_str
(
json_data
)
.expect
(
"Deserialization failed"
);
if
let
Content
::
ImageUrl
(
parts
)
=
message
.content
{
assert_eq!
(
parts
.len
(),
4
);
assert_eq!
(
parts
[
0
]
.r
#
type
,
ContentType
::
text
);
assert_eq!
(
parts
[
0
]
.text
.as_ref
()
.unwrap
(),
"This is a text message."
);
assert_eq!
(
parts
[
1
]
.r
#
type
,
ContentType
::
image_url
);
assert_eq!
(
parts
[
1
]
.image_url
.as_ref
()
.unwrap
()
.url
,
"https://example.com/image1.png"
);
}
else
{
panic!
(
"Expected Content::ImageUrl"
);
}
}
#[test]
fn
test_try_from_chat_completion_request_success
()
->
Result
<
(),
Box
<
dyn
Error
>>
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"test-model"
)
.add_user_message
(
"Hello!"
)
.add_assistant_message
(
"Hi there!"
)
.add_user_message
(
"How are you?"
)
.build
()
?
;
let
completion_request
:
common
::
CompletionRequest
=
request
.try_into
()
?
;
assert
!
(
matches!
(
completion_request
.prompt
,
common
::
PromptType
::
ChatCompletion
(
_
)
));
Ok
(())
}
#[test]
fn
test_chat_completion_sampling_params_with_valid_nvext
()
{
let
nvext
=
NvExt
{
ignore_eos
:
Some
(
true
),
repetition_penalty
:
Some
(
0.6
),
top_k
:
Some
(
3
),
use_raw_prompt
:
None
,
greed_sampling
:
None
,
annotations
:
None
,
};
let
request
=
ChatCompletionRequest
::
builder
()
.nvext
(
nvext
)
.model
(
"foo"
)
.add_system_message
(
"Hello!"
)
.build
()
.expect
(
"Failed to build request with valid nvext"
);
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.ignore_eos
,
Some
(
true
));
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.repetition_penalty
,
Some
(
0.6
)
);
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.top_k
,
Some
(
3
));
}
#[test]
fn
test_completion_sampling_params_without_nvext
()
{
let
request
=
ChatCompletionRequest
::
builder
()
.model
(
"foo"
)
.add_user_message
(
"Test"
)
.build
()
.unwrap
();
assert_eq!
(
request
.frequency_penalty
,
None
);
assert_eq!
(
request
.logprobs
,
None
);
}
#[test]
fn
test_completion_sampling_params_with_valid_nvext
()
{
let
nvext
=
NvExt
{
ignore_eos
:
Some
(
true
),
repetition_penalty
:
Some
(
0.6
),
top_k
:
Some
(
3
),
..
Default
::
default
()
};
let
request
=
ChatCompletionRequest
::
builder
()
.nvext
(
nvext
)
.model
(
"foo"
)
.add_user_message
(
"Test"
)
.build
()
.expect
(
"Failed to build request with valid nvext"
);
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.ignore_eos
,
Some
(
true
));
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.repetition_penalty
,
Some
(
0.6
)
);
assert_eq!
(
request
.nvext
.as_ref
()
.unwrap
()
.top_k
,
Some
(
3
));
}
// #[test]
// fn test_normalize_unicode_characters() {
// let str = "Hello there how are you\u{E0020}?".to_string();
// let normalized = str.sanitize_text();
// assert_eq!(normalized, "Hello there how are you?");
// }
// #[tokio::test]
// async fn test_chat_completion_request_filtered() {
// // Define input messages with Unicode character to filter
// let messages = vec![
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text(
// "Hello there how are you\u{E0020}?"
// .to_string()
// .normalize_unicode_characters(),
// ),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::assistant,
// content: Content::Text("How may I help you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Do something for me?".to_string()),
// name: None,
// },
// ];
// // Define expected filtered messages
// let expected = vec![
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Hello there how are you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::assistant,
// content: Content::Text("How may I help you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Do something for me?".to_string()),
// name: None,
// },
// ];
// // Build ChatCompletionRequest with filtering applied
// let request = ChatCompletionRequest::builder()
// .model("foo")
// .messages(messages)
// .build()
// .expect("Failed to build ChatCompletionRequest");
// // Validate each message matches the expected filtered content
// for (i, message) in request.messages.iter().enumerate() {
// assert_eq!(message.role, expected[i].role);
// if let Content::Text(ref content) = message.content {
// if let Content::Text(ref expected_content) = expected[i].content {
// assert_eq!(content, expected_content);
// }
// }
// }
// }
}
lib/llm/src/protocols/openai/chat_completions/aggregator.rs
View file @
86aff237
...
...
@@ -13,13 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::{
ChatCompletionChoice
,
ChatCompletionContent
,
ChatCompletionResponse
,
ChatCompletionResponseDelta
,
CompletionUsage
,
FinishReason
,
MessageRole
,
ServiceTier
,
};
use
super
::{
ChatCompletionResponse
,
ChatCompletionResponseDelta
};
use
crate
::
protocols
::{
codec
::{
Message
,
SseCodecError
},
common
::
ChatCompletionLogprobs
,
convert_sse_stream
,
Annotated
,
};
...
...
@@ -32,21 +28,21 @@ type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
pub
struct
DeltaAggregator
{
id
:
String
,
model
:
String
,
created
:
u
64
,
usage
:
Option
<
CompletionUsage
>
,
created
:
u
32
,
usage
:
Option
<
async_openai
::
types
::
CompletionUsage
>
,
system_fingerprint
:
Option
<
String
>
,
choices
:
HashMap
<
u
64
,
DeltaChoice
>
,
choices
:
HashMap
<
u
32
,
DeltaChoice
>
,
error
:
Option
<
String
>
,
service_tier
:
Option
<
ServiceTier
>
,
service_tier
:
Option
<
async_openai
::
types
::
ServiceTier
Response
>
,
}
// Holds the accumulated state of a choice
struct
DeltaChoice
{
index
:
u
64
,
index
:
u
32
,
text
:
String
,
role
:
Option
<
Message
Role
>
,
finish_reason
:
Option
<
FinishReason
>
,
logprobs
:
Option
<
ChatCompletion
Logprobs
>
,
role
:
Option
<
async_openai
::
types
::
Role
>
,
finish_reason
:
Option
<
async_openai
::
types
::
FinishReason
>
,
logprobs
:
Option
<
async_openai
::
types
::
ChatChoice
Logprobs
>
,
}
impl
Default
for
DeltaAggregator
{
...
...
@@ -92,19 +88,19 @@ impl DeltaAggregator {
// TODO(#14) - Aggregate Annotation
let
delta
=
delta
.data
.unwrap
();
aggregator
.id
=
delta
.id
;
aggregator
.model
=
delta
.model
;
aggregator
.created
=
delta
.created
;
aggregator
.service_tier
=
delta
.service_tier
;
if
let
Some
(
usage
)
=
delta
.usage
{
aggregator
.id
=
delta
.
inner.
id
;
aggregator
.model
=
delta
.
inner.
model
;
aggregator
.created
=
delta
.
inner.
created
;
aggregator
.service_tier
=
delta
.
inner.
service_tier
;
if
let
Some
(
usage
)
=
delta
.
inner.
usage
{
aggregator
.usage
=
Some
(
usage
);
}
if
let
Some
(
system_fingerprint
)
=
delta
.system_fingerprint
{
if
let
Some
(
system_fingerprint
)
=
delta
.
inner.
system_fingerprint
{
aggregator
.system_fingerprint
=
Some
(
system_fingerprint
);
}
// handle the choices
for
choice
in
delta
.choices
{
for
choice
in
delta
.
inner.
choices
{
let
state_choice
=
aggregator
.choices
...
...
@@ -141,12 +137,12 @@ impl DeltaAggregator {
let
mut
choices
:
Vec
<
_
>
=
aggregator
.choices
.into_values
()
.map
(
ChatCompletion
Choice
::
from
)
.map
(
async_openai
::
types
::
Chat
Choice
::
from
)
.collect
();
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
Ok
(
ChatCompletionResponse
{
let
inner
=
async_openai
::
types
::
Create
ChatCompletionResponse
{
id
:
aggregator
.id
,
created
:
aggregator
.created
,
usage
:
aggregator
.usage
,
...
...
@@ -155,21 +151,30 @@ impl DeltaAggregator {
system_fingerprint
:
aggregator
.system_fingerprint
,
choices
,
service_tier
:
aggregator
.service_tier
,
})
};
let
response
=
ChatCompletionResponse
{
inner
};
Ok
(
response
)
}
}
// todo - handle tool calls
impl
From
<
DeltaChoice
>
for
ChatCompletionChoice
{
#[allow(deprecated)]
impl
From
<
DeltaChoice
>
for
async_openai
::
types
::
ChatChoice
{
fn
from
(
delta
:
DeltaChoice
)
->
Self
{
ChatCompletionChoice
{
message
:
ChatCompletionContent
{
role
:
delta
.role
,
// ALLOW: function_call is deprecated
async_openai
::
types
::
ChatChoice
{
message
:
async_openai
::
types
::
ChatCompletionResponseMessage
{
role
:
delta
.role
.expect
(
"delta should have a Role"
),
content
:
Some
(
delta
.text
),
tool_calls
:
None
,
refusal
:
None
,
function_call
:
None
,
audio
:
None
,
},
index
:
delta
.index
,
finish_reason
:
delta
.finish_reason
.unwrap_or
(
FinishReason
::
length
)
,
finish_reason
:
delta
.finish_reason
,
logprobs
:
delta
.logprobs
,
}
}
...
...
@@ -192,37 +197,47 @@ impl ChatCompletionResponse {
#[cfg(test)]
mod
tests
{
use
crate
::
protocols
::
openai
::
chat_completions
::
ChatCompletionChoiceDelta
;
use
super
::
*
;
use
futures
::
stream
;
#[allow(deprecated)]
fn
create_test_delta
(
index
:
u
64
,
index
:
u
32
,
text
:
&
str
,
role
:
Option
<
Message
Role
>
,
finish_reason
:
Option
<
FinishReason
>
,
role
:
Option
<
async_openai
::
types
::
Role
>
,
finish_reason
:
Option
<
async_openai
::
types
::
FinishReason
>
,
)
->
Annotated
<
ChatCompletionResponseDelta
>
{
// ALLOW: function_call is deprecated
let
delta
=
async_openai
::
types
::
ChatCompletionStreamResponseDelta
{
content
:
Some
(
text
.to_string
()),
function_call
:
None
,
tool_calls
:
None
,
role
,
refusal
:
None
,
};
let
choice
=
async_openai
::
types
::
ChatChoiceStream
{
index
,
delta
,
finish_reason
,
logprobs
:
None
,
};
let
inner
=
async_openai
::
types
::
CreateChatCompletionStreamResponse
{
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b-instruct"
.to_string
(),
created
:
1234567890
,
service_tier
:
None
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
choice
],
object
:
"chat.completion"
.to_string
(),
};
let
data
=
ChatCompletionResponseDelta
{
inner
};
Annotated
{
data
:
Some
(
ChatCompletionResponseDelta
{
id
:
"test_id"
.to_string
(),
model
:
"meta/llama-3.1-8b-instruct"
.to_string
(),
created
:
1234567890
,
service_tier
:
None
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
ChatCompletionChoiceDelta
{
index
,
delta
:
ChatCompletionContent
{
role
,
content
:
Some
(
text
.to_string
()),
tool_calls
:
None
,
},
finish_reason
,
logprobs
:
None
,
}],
object
:
"chat.completion"
.to_string
(),
}),
data
:
Some
(
data
),
id
:
Some
(
"test_id"
.to_string
()),
event
:
None
,
comment
:
None
,
...
...
@@ -242,19 +257,20 @@ mod tests {
let
response
=
result
.unwrap
();
// Verify that the response is empty and has default values
assert_eq!
(
response
.id
,
""
);
assert_eq!
(
response
.model
,
""
);
assert_eq!
(
response
.created
,
0
);
assert
!
(
response
.usage
.is_none
());
assert
!
(
response
.system_fingerprint
.is_none
());
assert_eq!
(
response
.choices
.len
(),
0
);
assert
!
(
response
.service_tier
.is_none
());
assert_eq!
(
response
.
inner.
id
,
""
);
assert_eq!
(
response
.
inner.
model
,
""
);
assert_eq!
(
response
.
inner.
created
,
0
);
assert
!
(
response
.
inner.
usage
.is_none
());
assert
!
(
response
.
inner.
system_fingerprint
.is_none
());
assert_eq!
(
response
.
inner.
choices
.len
(),
0
);
assert
!
(
response
.
inner.
service_tier
.is_none
());
}
#[tokio::test]
async
fn
test_single_delta
()
{
// Create a sample delta
let
annotated_delta
=
create_test_delta
(
0
,
"Hello,"
,
Some
(
MessageRole
::
user
),
None
);
let
annotated_delta
=
create_test_delta
(
0
,
"Hello,"
,
Some
(
async_openai
::
types
::
Role
::
User
),
None
);
// Create a stream
let
stream
=
Box
::
pin
(
stream
::
iter
(
vec!
[
annotated_delta
]));
...
...
@@ -267,18 +283,18 @@ mod tests {
let
response
=
result
.unwrap
();
// Verify the response fields
assert_eq!
(
response
.id
,
"test_id"
);
assert_eq!
(
response
.model
,
"meta/llama-3.1-8b-instruct"
);
assert_eq!
(
response
.created
,
1234567890
);
assert
!
(
response
.usage
.is_none
());
assert
!
(
response
.system_fingerprint
.is_none
());
assert_eq!
(
response
.choices
.len
(),
1
);
let
choice
=
&
response
.choices
[
0
];
assert_eq!
(
response
.
inner.
id
,
"test_id"
);
assert_eq!
(
response
.
inner.
model
,
"meta/llama-3.1-8b-instruct"
);
assert_eq!
(
response
.
inner.
created
,
1234567890
);
assert
!
(
response
.
inner.
usage
.is_none
());
assert
!
(
response
.
inner.
system_fingerprint
.is_none
());
assert_eq!
(
response
.
inner.
choices
.len
(),
1
);
let
choice
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.message.content
.as_ref
()
.unwrap
(),
"Hello,"
);
assert
_eq
!
(
choice
.finish_reason
,
FinishReason
::
length
);
assert_eq!
(
choice
.message.role
.as_ref
()
.unwrap
(),
&
Message
Role
::
u
ser
);
assert
!
(
response
.service_tier
.is_none
());
assert
!
(
choice
.finish_reason
.is_none
()
);
assert_eq!
(
choice
.message.role
,
async_openai
::
types
::
Role
::
U
ser
);
assert
!
(
response
.
inner.
service_tier
.is_none
());
}
#[tokio::test]
...
...
@@ -286,8 +302,14 @@ mod tests {
// Create multiple deltas with the same choice index
// One will have a MessageRole and no FinishReason,
// the other will have a FinishReason and no MessageRole
let
annotated_delta1
=
create_test_delta
(
0
,
"Hello,"
,
Some
(
MessageRole
::
user
),
None
);
let
annotated_delta2
=
create_test_delta
(
0
,
" world!"
,
None
,
Some
(
FinishReason
::
stop
));
let
annotated_delta1
=
create_test_delta
(
0
,
"Hello,"
,
Some
(
async_openai
::
types
::
Role
::
User
),
None
);
let
annotated_delta2
=
create_test_delta
(
0
,
" world!"
,
None
,
Some
(
async_openai
::
types
::
FinishReason
::
Stop
),
);
// Create a stream
let
annotated_deltas
=
vec!
[
annotated_delta1
,
annotated_delta2
];
...
...
@@ -301,52 +323,63 @@ mod tests {
let
response
=
result
.unwrap
();
// Verify the response fields
assert_eq!
(
response
.choices
.len
(),
1
);
let
choice
=
&
response
.choices
[
0
];
assert_eq!
(
response
.
inner.
choices
.len
(),
1
);
let
choice
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice
.index
,
0
);
assert_eq!
(
choice
.message.content
.as_ref
()
.unwrap
(),
"Hello, world!"
);
assert_eq!
(
choice
.finish_reason
,
FinishReason
::
stop
);
assert_eq!
(
choice
.message.role
.as_ref
()
.unwrap
(),
&
MessageRole
::
user
);
assert_eq!
(
choice
.finish_reason
,
Some
(
async_openai
::
types
::
FinishReason
::
Stop
)
);
assert_eq!
(
choice
.message.role
,
async_openai
::
types
::
Role
::
User
);
}
#[allow(deprecated)]
#[tokio::test]
async
fn
test_multiple_choices
()
{
// Create a delta with multiple choices
let
delta
=
ChatCompletionResponseDelta
{
// ALLOW: function_call is deprecated
let
delta
=
async_openai
::
types
::
CreateChatCompletionStreamResponse
{
id
:
"test_id"
.to_string
(),
model
:
"test_model"
.to_string
(),
created
:
1234567890
,
service_tier
:
None
,
usage
:
None
,
system_fingerprint
:
None
,
choices
:
vec!
[
ChatCompletionChoiceDelta
{
async_openai
::
types
::
ChatChoiceStream
{
index
:
0
,
delta
:
ChatCompletionContent
{
role
:
Some
(
Message
Role
::
a
ssistant
),
delta
:
async_openai
::
types
::
ChatCompletionStreamResponseDelta
{
role
:
Some
(
async_openai
::
types
::
Role
::
A
ssistant
),
content
:
Some
(
"Choice 0"
.to_string
()),
function_call
:
None
,
tool_calls
:
None
,
refusal
:
None
,
},
finish_reason
:
Some
(
FinishReason
::
s
top
),
finish_reason
:
Some
(
async_openai
::
types
::
FinishReason
::
S
top
),
logprobs
:
None
,
},
ChatCompletionChoiceDelta
{
async_openai
::
types
::
ChatChoiceStream
{
index
:
1
,
delta
:
ChatCompletionContent
{
role
:
Some
(
Message
Role
::
a
ssistant
),
delta
:
async_openai
::
types
::
ChatCompletionStreamResponseDelta
{
role
:
Some
(
async_openai
::
types
::
Role
::
A
ssistant
),
content
:
Some
(
"Choice 1"
.to_string
()),
function_call
:
None
,
tool_calls
:
None
,
refusal
:
None
,
},
finish_reason
:
Some
(
FinishReason
::
s
top
),
finish_reason
:
Some
(
async_openai
::
types
::
FinishReason
::
S
top
),
logprobs
:
None
,
},
],
object
:
"chat.completion"
.to_string
(),
service_tier
:
None
,
};
let
data
=
ChatCompletionResponseDelta
{
inner
:
delta
};
// Wrap it in Annotated and create a stream
let
annotated_delta
=
Annotated
{
data
:
Some
(
d
el
ta
),
data
:
Some
(
d
a
ta
),
id
:
Some
(
"test_id"
.to_string
()),
event
:
None
,
comment
:
None
,
...
...
@@ -361,24 +394,24 @@ mod tests {
let
mut
response
=
result
.unwrap
();
// Verify the response fields
assert_eq!
(
response
.choices
.len
(),
2
);
response
.choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
// Ensure the choices are ordered
let
choice0
=
&
response
.choices
[
0
];
assert_eq!
(
response
.
inner.
choices
.len
(),
2
);
response
.
inner.
choices
.sort_by
(|
a
,
b
|
a
.index
.cmp
(
&
b
.index
));
// Ensure the choices are ordered
let
choice0
=
&
response
.
inner.
choices
[
0
];
assert_eq!
(
choice0
.index
,
0
);
assert_eq!
(
choice0
.message.content
.as_ref
()
.unwrap
(),
"Choice 0"
);
assert_eq!
(
choice0
.finish_reason
,
FinishReason
::
stop
);
assert_eq!
(
choice0
.
message.role
.as_ref
()
.unwrap
()
,
&
MessageRole
::
assistant
choice0
.
finish_reason
,
Some
(
async_openai
::
types
::
FinishReason
::
Stop
)
);
assert_eq!
(
choice0
.message.role
,
async_openai
::
types
::
Role
::
Assistant
);
let
choice1
=
&
response
.choices
[
1
];
let
choice1
=
&
response
.
inner.
choices
[
1
];
assert_eq!
(
choice1
.index
,
1
);
assert_eq!
(
choice1
.message.content
.as_ref
()
.unwrap
(),
"Choice 1"
);
assert_eq!
(
choice1
.finish_reason
,
FinishReason
::
stop
);
assert_eq!
(
choice1
.
message.role
.as_ref
()
.unwrap
()
,
&
MessageRole
::
assistant
choice1
.
finish_reason
,
Some
(
async_openai
::
types
::
FinishReason
::
Stop
)
);
assert_eq!
(
choice1
.message.role
,
async_openai
::
types
::
Role
::
Assistant
);
}
}
lib/llm/src/protocols/openai/chat_completions/delta.rs
View file @
86aff237
...
...
@@ -13,12 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::{
ChatCompletionChoiceDelta
,
ChatCompletionContent
,
ChatCompletionRequest
,
ChatCompletionResponseDelta
,
FinishReason
,
MessageRole
,
ServiceTier
,
};
use
super
::{
ChatCompletionRequest
,
ChatCompletionResponseDelta
};
use
crate
::
protocols
::
common
;
use
crate
::
protocols
::
openai
::
CompletionUsage
;
impl
ChatCompletionRequest
{
// put this method on the request
...
...
@@ -26,10 +22,10 @@ impl ChatCompletionRequest {
pub
fn
response_generator
(
&
self
)
->
DeltaGenerator
{
let
options
=
DeltaGeneratorOptions
{
enable_usage
:
true
,
enable_logprobs
:
self
.logprobs
.unwrap_or
(
false
),
enable_logprobs
:
self
.
inner.
logprobs
.unwrap_or
(
false
),
};
DeltaGenerator
::
new
(
self
.model
.clone
(),
options
)
DeltaGenerator
::
new
(
self
.
inner.
model
.clone
(),
options
)
}
}
...
...
@@ -43,11 +39,11 @@ pub struct DeltaGeneratorOptions {
pub
struct
DeltaGenerator
{
id
:
String
,
object
:
String
,
created
:
u
64
,
created
:
u
32
,
model
:
String
,
system_fingerprint
:
Option
<
String
>
,
service_tier
:
Option
<
ServiceTier
>
,
usage
:
CompletionUsage
,
service_tier
:
Option
<
async_openai
::
types
::
ServiceTier
Response
>
,
usage
:
async_openai
::
types
::
CompletionUsage
,
// counter on how many messages we have issued
msg_counter
:
u64
,
...
...
@@ -57,10 +53,22 @@ pub struct DeltaGenerator {
impl
DeltaGenerator
{
pub
fn
new
(
model
:
String
,
options
:
DeltaGeneratorOptions
)
->
Self
{
// SAFETY: This is a fun one to write. We are casting from u64 to u32
// which typically is unsafe due to loss of precision after it
// exceeds u32::MAX. Fortunately, this won't be an issue until
// 2106. So whoever is still maintaining this then, enjoy!
let
now
=
std
::
time
::
SystemTime
::
now
()
.duration_since
(
std
::
time
::
UNIX_EPOCH
)
.unwrap
()
.as_secs
();
.as_secs
()
as
u32
;
let
usage
=
async_openai
::
types
::
CompletionUsage
{
prompt_tokens
:
0
,
completion_tokens
:
0
,
total_tokens
:
0
,
prompt_tokens_details
:
None
,
completion_tokens_details
:
None
,
};
Self
{
id
:
format!
(
"chatcmpl-{}"
,
uuid
::
Uuid
::
new_v4
()),
...
...
@@ -69,46 +77,54 @@ impl DeltaGenerator {
model
,
system_fingerprint
:
None
,
service_tier
:
None
,
usage
:
CompletionUsage
::
default
()
,
usage
,
msg_counter
:
0
,
options
,
}
}
pub
fn
update_isl
(
&
mut
self
,
isl
:
i
32
)
{
pub
fn
update_isl
(
&
mut
self
,
isl
:
u
32
)
{
self
.usage.prompt_tokens
=
isl
;
}
#[allow(deprecated)]
pub
fn
create_choice
(
&
self
,
index
:
u
64
,
index
:
u
32
,
text
:
Option
<
String
>
,
finish_reason
:
Option
<
super
::
FinishReason
>
,
logprobs
:
Option
<
super
::
ChatCompletion
Logprobs
>
,
)
->
ChatCompletionResponse
Delta
{
//
todo - u
pdate for tool calling
let
delta
=
ChatCompletionContent
{
content
:
text
,
finish_reason
:
Option
<
async_openai
::
types
::
FinishReason
>
,
logprobs
:
Option
<
async_openai
::
types
::
ChatChoice
Logprobs
>
,
)
->
async_openai
::
types
::
Create
ChatCompletion
Stream
Response
{
//
TODO: U
pdate for tool calling
// ALLOW: function_call is deprecated
let
delta
=
async_openai
::
types
::
ChatCompletionStreamResponseDelta
{
role
:
if
self
.msg_counter
==
0
{
Some
(
Message
Role
::
a
ssistant
)
Some
(
async_openai
::
types
::
Role
::
A
ssistant
)
}
else
{
None
},
content
:
text
,
tool_calls
:
None
,
function_call
:
None
,
refusal
:
None
,
};
ChatCompletionResponseDelta
{
let
choice
=
async_openai
::
types
::
ChatChoiceStream
{
index
,
delta
,
finish_reason
,
logprobs
,
};
let
choices
=
vec!
[
choice
];
async_openai
::
types
::
CreateChatCompletionStreamResponse
{
id
:
self
.id
.clone
(),
object
:
self
.object
.clone
(),
created
:
self
.created
,
model
:
self
.model
.clone
(),
system_fingerprint
:
self
.system_fingerprint
.clone
(),
choices
:
vec!
[
ChatCompletionChoiceDelta
{
index
,
delta
,
finish_reason
,
logprobs
,
}],
choices
,
usage
:
if
self
.options.enable_usage
{
Some
(
self
.usage
.clone
())
}
else
{
...
...
@@ -126,17 +142,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
)
->
anyhow
::
Result
<
ChatCompletionResponseDelta
>
{
// aggregate usage
if
self
.options.enable_usage
{
self
.usage.completion_tokens
+=
delta
.token_ids
.len
()
as
i
32
;
self
.usage.completion_tokens
+=
delta
.token_ids
.len
()
as
u
32
;
}
// todo logprobs
let
logprobs
=
None
;
let
finish_reason
=
match
delta
.finish_reason
{
Some
(
common
::
FinishReason
::
EoS
)
=>
Some
(
FinishReason
::
s
top
),
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
FinishReason
::
s
top
),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
FinishReason
::
l
ength
),
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
FinishReason
::
cancelled
),
Some
(
common
::
FinishReason
::
EoS
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
S
top
),
Some
(
common
::
FinishReason
::
Stop
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
S
top
),
Some
(
common
::
FinishReason
::
Length
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
L
ength
),
Some
(
common
::
FinishReason
::
Cancelled
)
=>
Some
(
async_openai
::
types
::
FinishReason
::
Stop
),
Some
(
common
::
FinishReason
::
Error
(
err_msg
))
=>
{
return
Err
(
anyhow
::
anyhow!
(
err_msg
));
}
...
...
@@ -145,6 +161,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
// create choice
let
index
=
0
;
Ok
(
self
.create_choice
(
index
,
delta
.text
,
finish_reason
,
logprobs
))
let
stream_response
=
self
.create_choice
(
index
,
delta
.text
,
finish_reason
,
logprobs
);
Ok
(
ChatCompletionResponseDelta
{
inner
:
stream_response
,
})
}
}
lib/llm/src/protocols/openai/completions.rs
View file @
86aff237
...
...
@@ -22,7 +22,7 @@ use validator::Validate;
mod
aggregator
;
mod
delta
;
pub
use
aggregator
::
DeltaAggregator
;
//
pub use aggregator::DeltaAggregator;
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
...
...
@@ -56,13 +56,13 @@ pub struct CompletionRequest {
/// The token count of your prompt plus max_tokens cannot exceed the model's context length.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
max_tokens
:
Option
<
i
32
>
,
pub
max_tokens
:
Option
<
u
32
>
,
/// The minimum number of tokens to generate. We ignore stop tokens until we see this many
/// tokens. Leave this None unless you are working on the pre-processor.
#[serde(skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(into,
strip_option))]
pub
min_tokens
:
Option
<
i
32
>
,
pub
min_tokens
:
Option
<
u
32
>
,
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
/// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
...
...
@@ -248,7 +248,7 @@ impl CompletionRequestBuilder {
/// let request = CompletionRequest::builder()
/// .model("mixtral-8x7b-instruct-v0.1")
/// .prompt("Hello")
/// .max_tokens(16)
/// .max_tokens(16
_u32
)
/// .build()
/// .expect("Failed to build CompletionRequest");
/// ```
...
...
@@ -433,11 +433,11 @@ impl OpenAISamplingOptionsProvider for CompletionRequest {
}
impl
OpenAIStopConditionsProvider
for
CompletionRequest
{
fn
get_max_tokens
(
&
self
)
->
Option
<
i
32
>
{
fn
get_max_tokens
(
&
self
)
->
Option
<
u
32
>
{
self
.max_tokens
}
fn
get_min_tokens
(
&
self
)
->
Option
<
i
32
>
{
fn
get_min_tokens
(
&
self
)
->
Option
<
u
32
>
{
self
.min_tokens
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment