Commit 86aff237 authored by Paul Hendricks, committed by GitHub
Browse files

refactor: using async_openai


Co-authored-by: Graham King <grahamk@nvidia.com>
parent d694ca6e
......@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -995,6 +1034,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
......@@ -1157,6 +1207,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
......@@ -1202,8 +1258,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
......@@ -1351,6 +1409,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
......@@ -1596,6 +1672,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
......@@ -1605,6 +1690,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
......@@ -1824,6 +1915,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -2527,6 +2628,58 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2661,6 +2814,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
......@@ -2681,6 +2896,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
......@@ -2757,6 +2978,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
......@@ -2805,6 +3029,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -3123,6 +3357,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
......@@ -3261,6 +3498,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
......@@ -3612,6 +3864,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
......@@ -3740,6 +3993,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
......@@ -3957,6 +4216,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
......@@ -3989,6 +4261,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
......@@ -4060,6 +4355,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
......
......@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -987,6 +1026,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
......@@ -1149,6 +1199,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
......@@ -1194,8 +1250,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
......@@ -1362,6 +1420,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http 1.2.0",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
......@@ -1607,6 +1683,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
......@@ -1616,6 +1701,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
......@@ -1849,6 +1940,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -2563,6 +2664,58 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2697,6 +2850,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http 1.2.0",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
......@@ -2717,6 +2932,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
......@@ -2793,6 +3014,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
......@@ -2841,6 +3065,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -3170,6 +3404,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
......@@ -3331,6 +3568,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
......@@ -3682,6 +3934,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
......@@ -3810,6 +4063,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
......@@ -4027,6 +4286,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
......@@ -4059,6 +4331,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
......@@ -4130,6 +4425,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
......
......@@ -202,6 +202,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -371,6 +396,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -1410,6 +1449,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.73.0"
......@@ -1647,6 +1697,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
......@@ -1810,8 +1866,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
......@@ -2022,6 +2080,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
......@@ -2658,6 +2717,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -3731,6 +3800,58 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -3966,11 +4087,16 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
......@@ -3978,15 +4104,34 @@ dependencies = [
"system-configuration",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.11"
......@@ -4107,6 +4252,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
......@@ -4189,6 +4337,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -4843,11 +5001,27 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tio"
version = "0.1.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"clap",
......@@ -5276,6 +5450,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"async_zmq",
......@@ -5406,6 +5581,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.17"
......@@ -5709,6 +5890,19 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
......
......@@ -29,6 +29,7 @@ metal = ["triton-distributed-llm/metal"]
[dependencies]
anyhow = "1"
async-openai = "0.27.2"
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
clap = { version = "4.5", features = ["derive", "env"] }
......
......@@ -21,7 +21,6 @@ use std::{
use triton_distributed_llm::{
backend::Backend,
preprocessor::OpenAIPreprocessor,
protocols::openai::chat_completions::MessageRole,
types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
......@@ -38,7 +37,7 @@ use triton_distributed_runtime::{
use crate::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size.
const MAX_TOKENS: i32 = 8192;
const MAX_TOKENS: u32 = 8192;
/// Output of `isatty` if the fd is indeed a TTY
const IS_A_TTY: i32 = 1;
......@@ -96,11 +95,12 @@ pub async fn run(
main_loop(cancel_token, &service_name, engine, inspect_template).await
}
#[allow(deprecated)]
async fn main_loop(
cancel_token: CancellationToken,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
inspect_template: bool,
_inspect_template: bool,
) -> anyhow::Result<()> {
tracing::info!("Ctrl-c to exit");
let theme = dialoguer::theme::ColorfulTheme::default();
......@@ -141,30 +141,31 @@ async fn main_loop(
}
}
};
messages.push((MessageRole::user, prompt.clone()));
// Construct messages
let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt),
name: None,
},
);
messages.push(user_message);
// Request
let mut req_builder = ChatCompletionRequest::builder();
req_builder
let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(messages.clone())
.model(service_name)
.stream(true)
.max_tokens(MAX_TOKENS);
if inspect_template {
// This makes the pre-processor ignore stop tokens
req_builder.min_tokens(8192);
}
for (role, msg) in &messages {
match role {
MessageRole::user => {
req_builder.add_user_message(msg);
}
MessageRole::assistant => {
req_builder.add_assistant_message(msg);
}
x => panic!("Only 'user' and 'assistant' messages are supported, not {x}"),
}
}
let req = req_builder.build()?;
.max_tokens(MAX_TOKENS)
.build()?;
// TODO We cannot set min_tokens with async-openai
// if inspect_template {
// // This makes the pre-processor ignore stop tokens
// req_builder.min_tokens(8192);
// }
let req = ChatCompletionRequest { inner, nvext: None };
// Call the model
let mut stream = engine.generate(Context::new(req)).await?;
......@@ -174,7 +175,7 @@ async fn main_loop(
let mut assistant_message = String::new();
while let Some(item) = stream.next().await {
let data = item.data.as_ref().unwrap();
let entry = data.choices.first();
let entry = data.inner.choices.first();
let chat_comp = entry.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content {
let _ = stdout.write(c.as_bytes());
......@@ -188,7 +189,23 @@ async fn main_loop(
}
println!();
messages.push((MessageRole::assistant, assistant_message));
let assistant_content =
async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(
assistant_message,
);
// ALLOW: function_call is deprecated
let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
async_openai::types::ChatCompletionRequestAssistantMessage {
content: Some(assistant_content),
refusal: None,
name: None,
audio: None,
tool_calls: None,
function_call: None,
},
);
messages.push(assistant_message);
}
println!();
Ok(())
......
......@@ -18,9 +18,8 @@ use std::{sync::Arc, time::Duration};
use async_stream::stream;
use async_trait::async_trait;
use triton_distributed_llm::protocols::openai::chat_completions::FinishReason;
use triton_distributed_llm::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, Content,
ChatCompletionRequest, ChatCompletionResponseDelta,
};
use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
......@@ -53,25 +52,40 @@ impl
let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator();
let ctx = context.context();
let req = request.messages.into_iter().last().unwrap();
let prompt = match req.content {
Content::Text(prompt) => prompt,
_ => {
anyhow::bail!("Invalid request content field, expected Content::Text");
let req = request.inner.messages.into_iter().last().unwrap();
let prompt = match req {
async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
match user_msg.content {
async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
prompt
}
_ => anyhow::bail!("Invalid request content field, expected Content::Text"),
}
}
_ => anyhow::bail!("Invalid request type, expected User message"),
};
let output = stream! {
let mut id = 1;
for c in prompt.chars() {
// we are returning characters not tokens, so speed up some
tokio::time::sleep(TOKEN_ECHO_DELAY/2).await;
let delta = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(delta), event: None, comment: None };
let inner = deltas.create_choice(0, Some(c.to_string()), None, None);
let response = ChatCompletionResponseDelta {
inner,
};
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let stop_delta = deltas.create_choice(0, None, Some(FinishReason::stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(stop_delta), event: None, comment: None };
let inner = deltas.create_choice(0, None, Some(async_openai::types::FinishReason::Stop), None);
let response = ChatCompletionResponseDelta {
inner,
};
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
......@@ -151,6 +151,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -320,6 +345,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -986,6 +1025,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
......@@ -1148,6 +1198,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
......@@ -1193,8 +1249,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
......@@ -1348,6 +1406,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
......@@ -1593,6 +1669,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "iovec"
version = "0.1.4"
......@@ -1602,6 +1687,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
......@@ -1842,6 +1933,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -2545,6 +2646,58 @@ dependencies = [
"syn 2.0.96",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2679,6 +2832,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
......@@ -2700,6 +2915,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
......@@ -2776,6 +2997,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
......@@ -2824,6 +3048,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -3148,6 +3382,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
......@@ -3286,6 +3523,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
......@@ -3637,6 +3889,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
......@@ -3765,6 +4018,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.14"
......@@ -3982,6 +4241,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
......@@ -4014,6 +4286,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
......@@ -4085,6 +4380,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
......
......@@ -163,6 +163,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -332,6 +357,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -1009,6 +1048,17 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
......@@ -1171,6 +1221,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
......@@ -1216,8 +1272,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
......@@ -1365,6 +1423,24 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2"
dependencies = [
"futures-util",
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
......@@ -1610,6 +1686,15 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "inventory"
version = "0.3.19"
......@@ -1628,6 +1713,12 @@ dependencies = [
"libc",
]
[[package]]
name = "ipnet"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
......@@ -1847,6 +1938,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -2594,6 +2695,58 @@ dependencies = [
"serde",
]
[[package]]
name = "quinn"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef"
dependencies = [
"bytes",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.11",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2728,6 +2881,68 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
dependencies = [
"base64 0.22.1",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-util",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
......@@ -2749,6 +2964,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.4.1"
......@@ -2825,6 +3046,9 @@ name = "rustls-pki-types"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
[[package]]
name = "rustls-webpki"
......@@ -2873,6 +3097,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -3197,6 +3431,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
......@@ -3335,6 +3572,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.21.0"
......@@ -3686,6 +3938,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
......@@ -3834,6 +4087,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.16"
......@@ -4051,6 +4310,19 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
"once_cell",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
......@@ -4083,6 +4355,29 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
......@@ -4154,6 +4449,36 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
......
......@@ -129,9 +129,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.95"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04"
checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4"
[[package]]
name = "arbitrary"
......@@ -202,6 +202,31 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d126927c78e1562d7e8473008ac8b082318c04d69e3a83e3495a563f8b84a66"
dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-stream"
version = "0.3.6"
......@@ -371,6 +396,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -460,15 +499,16 @@ checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "blake3"
version = "1.5.5"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e"
checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937"
dependencies = [
"arrayref",
"arrayvec",
"cc",
"cfg-if 1.0.0",
"constant_time_eq",
"memmap2",
]
[[package]]
......@@ -662,9 +702,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.12"
version = "1.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "755717a7de9ec452bf7f3f1a3099085deabd7f2962b861dae91ecd7a365903d2"
checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af"
dependencies = [
"jobserver",
"libc",
......@@ -730,9 +770,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.29"
version = "4.5.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acebd8ad879283633b343856142139f2da2317c96b05b4dd6181c61e2480184"
checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767"
dependencies = [
"clap_builder",
"clap_derive",
......@@ -740,9 +780,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.29"
version = "4.5.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ba32cbda51c7e1dfd49acc1457ba1a7dec5b64fe360e828acb13ca8dc9c2f9"
checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863"
dependencies = [
"anstream",
"anstyle",
......@@ -990,9 +1030,9 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.13.4"
version = "0.13.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b68d7c284d40d96a4251330ab583c2718b412f4fc53239d295b3a1f8735f426"
checksum = "4e29ce3bfa797c1183053ceb496316203ef561c183941c3c181500d9ade6daf4"
dependencies = [
"half",
"libloading",
......@@ -1096,9 +1136,9 @@ dependencies = [
[[package]]
name = "data-encoding"
version = "2.7.0"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f"
checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010"
[[package]]
name = "defmac"
......@@ -1320,9 +1360,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.13.0"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d"
dependencies = [
"serde",
]
......@@ -1382,9 +1422,9 @@ dependencies = [
[[package]]
name = "equivalent"
version = "1.0.1"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "erased-serde"
......@@ -1422,7 +1462,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0452bcc559431b16f472b7ab86e2f9ccd5f3c2da3795afbd6b773665e047fe"
dependencies = [
"http",
"prost 0.13.4",
"prost 0.13.5",
"tokio",
"tokio-stream",
"tonic",
......@@ -1431,6 +1471,17 @@ dependencies = [
"tower-service",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
dependencies = [
"futures-core",
"nom",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.73.0"
......@@ -1486,15 +1537,15 @@ dependencies = [
[[package]]
name = "fixedbitset"
version = "0.4.2"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
[[package]]
name = "flate2"
version = "1.0.35"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
dependencies = [
"crc32fast",
"miniz_oxide",
......@@ -1892,9 +1943,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
dependencies = [
"atomic-waker",
"bytes",
......@@ -2086,6 +2137,7 @@ dependencies = [
"hyper",
"hyper-util",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
"tokio-rustls",
......@@ -2551,9 +2603,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
version = "0.2.169"
version = "0.2.170"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828"
[[package]]
name = "libloading"
......@@ -2642,9 +2694,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.25"
version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "lrtable"
......@@ -2757,6 +2809,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minijinja"
version = "2.7.0"
......@@ -2787,9 +2849,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.3"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924"
checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5"
dependencies = [
"adler2",
"simd-adler32",
......@@ -2937,7 +2999,7 @@ dependencies = [
"tqdm",
"tracing",
"tracing-subscriber",
"uuid 1.13.1",
"uuid 1.14.0",
"variantly",
"vob",
]
......@@ -3021,9 +3083,9 @@ checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "native-tls"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
......@@ -3289,9 +3351,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.20.2"
version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "onig"
......@@ -3317,9 +3379,9 @@ dependencies = [
[[package]]
name = "openssl"
version = "0.10.70"
version = "0.10.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6"
checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
dependencies = [
"bitflags 2.8.0",
"cfg-if 1.0.0",
......@@ -3349,9 +3411,9 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.105"
version = "0.9.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc"
checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
dependencies = [
"cc",
"libc",
......@@ -3495,9 +3557,9 @@ dependencies = [
[[package]]
name = "petgraph"
version = "0.6.5"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
dependencies = [
"fixedbitset",
"indexmap 2.7.1",
......@@ -3566,9 +3628,9 @@ dependencies = [
[[package]]
name = "portable-atomic"
version = "1.10.0"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6"
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
[[package]]
name = "powerfmt"
......@@ -3695,28 +3757,28 @@ dependencies = [
[[package]]
name = "prost"
version = "0.13.4"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5"
dependencies = [
"bytes",
"prost-derive 0.13.4",
"prost-derive 0.13.5",
]
[[package]]
name = "prost-build"
version = "0.13.4"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck 0.5.0",
"itertools 0.13.0",
"itertools 0.14.0",
"log",
"multimap",
"once_cell",
"petgraph",
"prettyplease",
"prost 0.13.4",
"prost 0.13.5",
"prost-types",
"regex",
"syn 2.0.98",
......@@ -3738,12 +3800,12 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.13.4"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.13.0",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.98",
......@@ -3751,11 +3813,11 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.13.4"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16"
dependencies = [
"prost 0.13.4",
"prost 0.13.5",
]
[[package]]
......@@ -3900,9 +3962,9 @@ dependencies = [
[[package]]
name = "quinn-udp"
version = "0.5.9"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904"
checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944"
dependencies = [
"cfg_aliases",
"libc",
......@@ -4037,9 +4099,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "redox_syscall"
version = "0.5.8"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
dependencies = [
"bitflags 2.8.0",
]
......@@ -4162,12 +4224,14 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls-native-certs 0.8.1",
"rustls-pemfile",
"rustls-pki-types",
"serde",
......@@ -4190,17 +4254,32 @@ dependencies = [
"windows-registry",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.8"
version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73"
dependencies = [
"cc",
"cfg-if 1.0.0",
"getrandom 0.2.15",
"libc",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
......@@ -4242,9 +4321,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "2.1.0"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
......@@ -4270,9 +4349,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.22"
version = "0.23.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7"
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
dependencies = [
"log",
"once_cell",
......@@ -4391,9 +4470,9 @@ dependencies = [
[[package]]
name = "schemars"
version = "0.8.21"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92"
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"schemars_derive",
......@@ -4403,9 +4482,9 @@ dependencies = [
[[package]]
name = "schemars_derive"
version = "0.8.21"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e"
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
dependencies = [
"proc-macro2",
"quote",
......@@ -4419,6 +4498,16 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
......@@ -4504,9 +4593,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
[[package]]
name = "serde"
version = "1.0.217"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
dependencies = [
"serde_derive",
]
......@@ -4526,9 +4615,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.217"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
dependencies = [
"proc-macro2",
"quote",
......@@ -4548,9 +4637,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.138"
version = "1.0.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
dependencies = [
"indexmap 2.7.1",
"itoa",
......@@ -4733,9 +4822,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.2"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
[[package]]
name = "socket2"
......@@ -4770,12 +4859,6 @@ dependencies = [
"vob",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.3"
......@@ -5272,9 +5355,9 @@ dependencies = [
[[package]]
name = "toktrie"
version = "0.6.28"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f9c32a81c3faff7dde7909b471a5c39970e684b7fb88994a0fe4d9a3fb3a2b1"
checksum = "1d6b6bcdb5d6345ffe9504e26906dd61c118c75191355f6219ada2854fe1421b"
dependencies = [
"anyhow",
"bytemuck",
......@@ -5299,16 +5382,16 @@ dependencies = [
[[package]]
name = "toktrie_hf_tokenizers"
version = "0.6.28"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e43a051aef243334de35fe3ba70060e11aa373eeb193cc98cce37383d312002"
checksum = "f33000768c9ef47791df9a7da0b2bcd06f758c93dac13af1b9df25a84be0a204"
dependencies = [
"anyhow",
"log",
"serde",
"serde_json",
"tokenizers",
"toktrie 0.6.28",
"toktrie 0.6.29",
]
[[package]]
......@@ -5334,9 +5417,9 @@ dependencies = [
[[package]]
name = "toml_edit"
version = "0.22.23"
version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee"
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
"indexmap 2.7.1",
"serde",
......@@ -5365,7 +5448,7 @@ dependencies = [
"hyper-util",
"percent-encoding",
"pin-project",
"prost 0.13.4",
"prost 0.13.5",
"socket2",
"tokio",
"tokio-stream",
......@@ -5529,6 +5612,7 @@ name = "triton-distributed-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"async_zmq",
......@@ -5568,12 +5652,12 @@ dependencies = [
"tokio",
"tokio-stream",
"tokio-util",
"toktrie 0.6.28",
"toktrie_hf_tokenizers 0.6.28",
"toktrie 0.6.29",
"toktrie_hf_tokenizers 0.6.29",
"tracing",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid 1.13.1",
"uuid 1.14.0",
"validator",
"xxhash-rust",
]
......@@ -5617,7 +5701,7 @@ dependencies = [
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid 1.13.1",
"uuid 1.14.0",
"validator",
"xxhash-rust",
]
......@@ -5647,9 +5731,9 @@ checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
[[package]]
name = "typenum"
version = "1.17.0"
version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "ucd-trie"
......@@ -5678,11 +5762,17 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.16"
version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe"
[[package]]
name = "unicode-normalization-alignments"
......@@ -5789,9 +5879,9 @@ dependencies = [
[[package]]
name = "uuid"
version = "1.13.1"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0"
checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1"
dependencies = [
"getrandom 0.3.1",
"serde",
......@@ -6280,9 +6370,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.1"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86e376c75f4f43f44db463cf729e0d3acbf954d13e22c51e26e4c264b4ab545f"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
dependencies = [
"memchr",
]
......
......@@ -81,6 +81,7 @@ uuid = { workspace = true }
xxhash-rust = { workspace = true }
strum = { workspace = true }
async-openai = "0.27.2"
blake3 = "1"
regex = "1"
......
......@@ -15,6 +15,7 @@
use std::{cmp::min, num::NonZero, path::Path, sync::Arc};
use async_openai::types::FinishReason;
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
......@@ -33,8 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use crate::protocols::openai::chat_completions::{
ChatCompletionChoiceDelta, ChatCompletionContent, ChatCompletionRequest,
ChatCompletionResponseDelta, Content, FinishReason, MessageRole,
ChatCompletionRequest, ChatCompletionResponseDelta,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
......@@ -174,11 +174,14 @@ impl
let (tx, mut rx) = channel(10_000);
let maybe_tok = self.pipeline.lock().await.tokenizer();
let mut prompt_tokens = 0;
let mut prompt_tokens = 0i32;
let mut messages = vec![];
for m in request.messages {
let content = match m.content {
Content::Text(prompt) => {
for m in request.inner.messages {
let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
continue;
};
let content = match inner_m.content {
async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
if let Some(tok) = maybe_tok.as_ref() {
prompt_tokens = tok
.encode(prompt.clone(), false)
......@@ -187,12 +190,12 @@ impl
}
prompt
}
Content::ImageUrl(_) => {
anyhow::bail!("Content::ImageUrl type is not supported");
_ => {
anyhow::bail!("Only Text type is supported");
}
};
let r = IndexMap::from([
("role".to_string(), Either::Left(m.role.to_string())),
("role".to_string(), Either::Left("user".to_string())),
("content".to_string(), Either::Left(content)),
]);
messages.push(r);
......@@ -204,7 +207,11 @@ impl
// level.
//tracing::info!(prompt_tokens, "Received prompt");
let limit = DEFAULT_MAX_TOKENS - prompt_tokens;
let max_output_tokens = min(request.max_tokens.unwrap_or(limit), limit);
#[allow(deprecated)]
let max_output_tokens = min(
request.inner.max_tokens.map(|x| x as i32).unwrap_or(limit),
limit,
);
let mistralrs_request = Request::Normal(NormalRequest {
messages: RequestMessage::Chat(messages),
......@@ -247,35 +254,39 @@ impl
.unwrap_or(0);
}
let finish_reason = match &c.choices[0].finish_reason {
Some(fr) => Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::null)),
Some(_fr) => Some(FinishReason::Stop), //Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::Stop)),
None if used_output_tokens >= max_output_tokens => {
tracing::debug!(used_output_tokens, max_output_tokens, "Met or exceed max_tokens. Stopping.");
Some(FinishReason::length)
Some(FinishReason::Length)
}
None => None,
};
//tracing::trace!("from_assistant: {from_assistant}");
let delta = ChatCompletionResponseDelta{
#[allow(deprecated)]
let inner = async_openai::types::CreateChatCompletionStreamResponse{
id: c.id,
choices: vec![ChatCompletionChoiceDelta{
choices: vec![async_openai::types::ChatChoiceStream{
index: 0,
delta: ChatCompletionContent{
delta: async_openai::types::ChatCompletionStreamResponseDelta{
//role: c.choices[0].delta.role,
role: Some(MessageRole::assistant),
role: Some(async_openai::types::Role::Assistant),
content: Some(from_assistant),
tool_calls: None,
refusal: None,
function_call: None,
},
logprobs: None,
finish_reason,
}],
model: c.model,
created: c.created as u64,
created: c.created as u32,
object: c.object.clone(),
usage: None,
system_fingerprint: Some(c.system_fingerprint),
service_tier: None,
};
let delta = ChatCompletionResponseDelta{inner};
let ann = Annotated{
id: None,
data: Some(delta),
......
......@@ -220,17 +220,21 @@ async fn chat_completions(
let request_id = uuid::Uuid::new_v4().to_string();
// todo - decide on default
let streaming = request.stream.unwrap_or(false);
let streaming = request.inner.stream.unwrap_or(false);
// update the request to always stream
let request = ChatCompletionRequest {
let inner_request = async_openai::types::CreateChatCompletionRequest {
stream: Some(true),
..request
..request.inner
};
let request = ChatCompletionRequest {
inner: inner_request,
nvext: None,
};
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let model = &request.model;
let model = &request.inner.model;
// todo - determine the proper error code for when a request model is not present
tracing::trace!("Getting chat completions engine for model: {}", model);
......
......@@ -275,7 +275,7 @@ impl
let (common_request, annotations) = self.preprocess_request(&request)?;
// update isl
response_generator.update_isl(common_request.token_ids.len() as i32);
response_generator.update_isl(common_request.token_ids.len() as u32);
// repack the common completion request
let common_request = context.map(|_| common_request);
......
......@@ -18,35 +18,37 @@ use super::*;
use minijinja::{context, value::Value};
use crate::protocols::openai::{
chat_completions::{ChatCompletionMessage, ChatCompletionRequest, Content, MessageRole},
completions::CompletionRequest,
chat_completions::ChatCompletionRequest, completions::CompletionRequest,
};
use tracing;
impl OAIChatLikeRequest for ChatCompletionRequest {
fn messages(&self) -> Value {
Value::from_serialize(&self.messages)
Value::from_serialize(&self.inner.messages)
}
fn tools(&self) -> Option<Value> {
if self.tools.is_none() {
if self.inner.tools.is_none() {
None
} else {
Some(Value::from_serialize(&self.tools))
Some(Value::from_serialize(&self.inner.tools))
}
}
fn tool_choice(&self) -> Option<Value> {
if self.tool_choice.is_none() {
if self.inner.tool_choice.is_none() {
None
} else {
Some(Value::from_serialize(&self.tool_choice))
Some(Value::from_serialize(&self.inner.tool_choice))
}
}
fn should_add_generation_prompt(&self) -> bool {
if let Some(last) = self.messages.last() {
last.role == MessageRole::user
if let Some(last) = self.inner.messages.last() {
matches!(
last,
async_openai::types::ChatCompletionRequestMessage::User(_)
)
} else {
true
}
......@@ -54,13 +56,22 @@ impl OAIChatLikeRequest for ChatCompletionRequest {
}
impl OAIChatLikeRequest for CompletionRequest {
fn messages(&self) -> Value {
let message = ChatCompletionMessage {
role: MessageRole::user,
content: Content::Text(self.prompt.clone()),
name: None,
};
Value::from_serialize(vec![message])
fn messages(&self) -> minijinja::value::Value {
let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
self.prompt.clone(),
),
name: None,
},
);
// Convert to a JSON string first
let json_string =
serde_json::to_string(&vec![message]).expect("Serialization to JSON string failed");
// Convert to MiniJinja Value
minijinja::value::Value::from_safe_string(json_string)
}
fn should_add_generation_prompt(&self) -> bool {
......
......@@ -66,6 +66,33 @@ pub enum FinishReason {
Cancelled,
}
impl std::fmt::Display for FinishReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FinishReason::EoS => write!(f, "eos"),
FinishReason::Length => write!(f, "length"),
FinishReason::Stop => write!(f, "stop"),
FinishReason::Error(msg) => write!(f, "error: {}", msg),
FinishReason::Cancelled => write!(f, "cancelled"),
}
}
}
impl std::str::FromStr for FinishReason {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eos" => Ok(FinishReason::EoS),
"length" => Ok(FinishReason::Length),
"stop" => Ok(FinishReason::Stop),
"cancelled" => Ok(FinishReason::Cancelled),
s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
_ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
}
}
}
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will
......
......@@ -147,9 +147,9 @@ trait OpenAISamplingOptionsProvider {
}
trait OpenAIStopConditionsProvider {
fn get_max_tokens(&self) -> Option<i32>;
fn get_max_tokens(&self) -> Option<u32>;
fn get_min_tokens(&self) -> Option<i32>;
fn get_min_tokens(&self) -> Option<u32>;
fn get_stop(&self) -> Option<Vec<String>>;
......@@ -200,7 +200,7 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
let max_tokens = self.get_max_tokens().map(|x| x as u32);
let max_tokens = self.get_max_tokens();
let min_tokens = self.get_min_tokens();
let stop = self.get_stop();
......@@ -218,7 +218,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
Ok(common::StopConditions {
max_tokens,
min_tokens: min_tokens.map(|v| v as u32),
min_tokens,
stop,
stop_token_ids_hidden: None,
ignore_eos,
......@@ -321,6 +321,7 @@ pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debu
response: common::llm_backend::BackendOutput,
) -> Result<ResponseType>;
}
#[cfg(test)]
mod tests {
......
......@@ -13,775 +13,46 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::collections::VecDeque;
use std::fmt;
use std::fmt::Display;
use derive_builder::Builder;
use serde::de::{self, SeqAccess, Visitor};
use serde::ser::SerializeMap;
use super::nvext::NvExt;
use super::nvext::NvExtProvider;
use super::OpenAISamplingOptionsProvider;
use super::OpenAIStopConditionsProvider;
use serde::{Deserialize, Serialize};
use serde::{Deserializer, Serializer};
use serde_json::Value;
use triton_distributed_runtime::protocols::annotated::AnnotationsProvider;
use validator::Validate;
mod aggregator;
mod delta;
use super::nvext::NvExtProvider;
pub use super::{CompletionTokensDetails, CompletionUsage, PromptTokensDetails};
// pub use aggregator::DeltaAggregator;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
use super::{
common::{self, ChatCompletionLogprobs, SamplingOptionsProvider, StopConditionsProvider},
nvext::NvExt,
validate_logit_bias, ContentProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
};
use triton_distributed_runtime::protocols::annotated::AnnotationsProvider;
/// Request object which is used to generate chat completions.
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[builder(build_fn(private, name = "build_internal", validate = "Self::validate"))]
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionRequest {
/// Multi-turn chat messages.
///
/// NIM Compatibility:
/// Multi-turn chat models vary, some of which work with the OpenAI ChatGPT format, while others
/// will require `NvExt`.
pub messages: Vec<ChatCompletionMessage>,
/// Name of the model
#[builder(setter(into))]
pub model: String,
/// The maximum number of tokens that can be generated in the completion.
/// The token count of your prompt plus max_tokens cannot exceed the model's context length.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
#[validate(range(min = 1))]
pub max_tokens: Option<i32>,
/// The minimum number of tokens to generate. We ignore stop tokens until we see this many
/// tokens. Leave this None unless you are working on the pre-processor.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub min_tokens: Option<i32>,
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
/// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
///
/// NIM Compatibility:
/// The NIM SDK can send extra meta data in the SSE stream using the `:` comment, `event:`,
/// or `id:` fields. See the `enable_sse_metadata` field in the NvExt object.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub stream: Option<bool>,
/// How many chat completion choices to generate for each input message.
///
/// NIM Compatibility:
/// Values greater than 1 are not currently supported by NIM.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub n: Option<i32>,
/// What sampling `temperature` to use, between 0 and 2. Higher values like 0.8 will make the
/// output more random, while lower values like 0.2 will make it more focused and deterministic.
/// OpenAI defaults to 1.0; however, in this crate, the default is None, and model-specific defaults
/// can be applied later as part of associating the request with a given model.
///
/// OpenAI generally recommend altering this or `top_p` but not both.
///
/// TODO(): Add a model specific validation which could enforce only a single type of sampling can be used.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = "super::MIN_TEMPERATURE", max = "super::MAX_TEMPERATURE"))]
#[builder(default, setter(into, strip_option))]
pub temperature: Option<f32>,
/// An alternative to sampling with `temperature`, called nucleus sampling, where the model
/// considers the results of the tokens with `top_p` probability mass. So 0.1 means only the tokens
/// comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or `temperature` but not both.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = "super::MIN_TOP_P", max = "super::MAX_TOP_P"))]
#[builder(default, setter(into, strip_option))]
pub top_p: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
/// in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(
min = "super::MIN_FREQUENCY_PENALTY",
max = "super::MAX_FREQUENCY_PENALTY"
))]
#[builder(default, setter(into, strip_option))]
pub frequency_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
/// the text so far, increasing the model's likelihood to talk about new topics.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(
min = "super::MIN_PRESENCE_PENALTY",
max = "super::MAX_PRESENCE_PENALTY"
))]
#[builder(default, setter(into, strip_option))]
pub presence_penalty: Option<f32>,
/// OpenAI specific API fields:
/// See: <https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format>
///
/// NIM Compatibility:
/// This option is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default)]
pub response_format: Option<Value>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(length(max = 4))]
#[builder(default, setter(into, strip_option))]
pub stop: Option<Vec<String>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
/// of each output token returned in the content of message.
///
/// Not all models support logprobs. If logprobs is set to true for a model that does not support it,
/// the request will be processed as if logprobs is set to false.
///
/// NIM Compatibility:
/// TODO - Add a NvExt `strict` object which will disable relaxing of model specific limitations; meaning,
/// if the user requests `logprobs` and the model does not support them, the request will fail wth an error.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub logprobs: Option<bool>,
/// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
/// each with an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0, max = 20))]
#[builder(default, setter(into, strip_option))]
pub top_logprobs: Option<i32>,
/// Modify the likelihood of specified tokens appearing in the completion.
///
/// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an
/// associated bias value from -100 to 100. You can use this tokenizer tool to convert text to token IDs.
/// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact
/// effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of
/// selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
///
/// As specified in the OpenAI examples, this is a map of tokens_ids as strings to a bias value that
/// is an integer.
///
/// However, the OpenAI blog using the SDK shows that it can also be specified more accurately as a
/// map of token_ids as ints to a bias value that is also an int.
///
/// NIM Compatibility:
/// In the conversion of the OpenAI request to the internal NIM format, the keys of this map will be
/// validated to ensure they are integers. Since different models may have different tokenizers, the
/// range and values will again be validated on the compute backend to ensure they map to valid tokens
/// in the vocabulary of the model.
///
/// ```
/// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
///
/// let request = CompletionRequest::builder()
/// .prompt("What is the meaning of life?")
/// .model("meta/llama-3.1-8b-instruct")
/// .add_logit_bias(1337, -100) // using an int as a key is ok
/// .add_logit_bias("42", 100) // using a string as a key is also ok
/// .build()
/// .expect("Should not fail");
///
/// assert!(CompletionRequest::builder()
/// .prompt("What is the meaning of life?")
/// .model("meta/llama-3.1-8b-instruct")
/// .add_logit_bias("some non int", -100)
/// .build()
/// .is_err());
/// ```
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_logit_bias"))]
#[builder(default, setter(into, strip_option))]
pub logit_bias: Option<HashMap<String, i32>>,
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
///
/// NIM Compatibility:
/// If provided, then the value of this field will be included in the trace metadata and the accounting
/// data (if enabled).
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically, such that repeated
/// requests with the same seed and parameters should return the same result. Determinism is not guaranteed,
/// and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub seed: Option<i64>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to
/// provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.
///
/// NIM Compatibility:
/// This field is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default)]
pub tools: Option<Vec<Tool>>,
/// Controls which (if any) function is called by the model. none means the model will not call a function
/// and instead generates a message. auto means the model can pick between generating a message or calling
/// a function. Specifying a particular function via {"type": "function", "function": {"name": "my_function"}}
/// forces the model to call that function.
///
/// `none` is the default when no functions are present. `auto` is the default if functions are present.
///
/// NIM Compatibility:
/// This field is not currently supported by NIM LLM. An error will be returned if this field is set.
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_tool_choice")]
#[builder(default)]
pub tool_choice: Option<ToolChoiceType>,
/// Additional parameters supported by NIM backends
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionRequest,
pub nvext: Option<NvExt>,
}
impl ChatCompletionRequest {
pub fn builder() -> ChatCompletionRequestBuilder {
ChatCompletionRequestBuilder::default()
}
}
impl ChatCompletionRequestBuilder {
    // Pre-build validation hook with access to the raw builder state.
    // NOTE(review): nothing in this impl calls `validate` and it currently always
    // succeeds; presumably it is wired in via a derive_builder
    // `build_fn(validate = ...)` attribute on the struct — confirm.
    fn validate(&self) -> Result<(), String> {
        Ok(())
    }
    /// Builds and validates the ChatCompletionRequest
    ///
    /// ```rust
    /// use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;
    ///
    /// let request = ChatCompletionRequest::builder()
    ///     .model("mixtral-8x7b-instruct-v0.1")
    ///     .add_user_message("Hello")
    ///     .max_tokens(16)
    ///     .build()
    ///     .expect("Failed to build ChatCompletionRequest");
    /// ```
    pub fn build(&self) -> anyhow::Result<ChatCompletionRequest> {
        // Calls the build_private, validates the result, then performs addition
        // post build validation where we are looking a mutually exclusive fields
        // and ensuring that there are not mutually exclusive collisions.
        let request = self
            .build_internal()
            .map_err(|e| anyhow::anyhow!("Failed to build ChatCompletionRequest: {}", e))?;
        // Field-level checks from the `validator` derive (ranges, lengths, customs).
        request
            .validate()
            .map_err(|e| anyhow::anyhow!("Failed to validate ChatCompletionRequest: {}", e))?;
        // check mutually exclusive fields:
        // top_logprobs is only meaningful when logprobs == Some(true)
        if request.top_logprobs.is_some() {
            if request.logprobs.is_none() {
                anyhow::bail!("top_logprobs requires logprobs to be set to true");
            }
            if let Some(logprobs) = request.logprobs {
                if !logprobs {
                    anyhow::bail!("top_logprobs requires logprobs to be set to true");
                }
            }
        }
        Ok(request)
    }
    /// Add a message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
    /// This will either create or append to the `Vec<ChatCompletionMessage>`
    pub fn add_message(&mut self, message: ChatCompletionMessage) -> &mut Self {
        // If messages exist we get them or we create new messages with Vec::new
        self.messages.get_or_insert_with(Vec::new).push(message);
        self
    }
    /// Add a user message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
    pub fn add_user_message(&mut self, content: impl Into<String>) -> &mut Self {
        self.add_message(ChatCompletionMessage {
            role: MessageRole::user,
            content: Content::Text(content.into()),
            name: None,
        })
    }
    /// Add an assistant message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
    pub fn add_assistant_message(&mut self, content: impl Into<String>) -> &mut Self {
        self.add_message(ChatCompletionMessage {
            role: MessageRole::assistant,
            content: Content::Text(content.into()),
            name: None,
        })
    }
    /// Add a system message to the `Vec<ChatCompletionMessage>` in the ChatCompletionRequest
    pub fn add_system_message(&mut self, content: impl Into<String>) -> &mut Self {
        self.add_message(ChatCompletionMessage {
            role: MessageRole::system,
            content: Content::Text(content.into()),
            name: None,
        })
    }
    /// Add a stop condition to the `Vec<String>` in the ChatCompletionRequest
    /// This will either create or append to the `Vec<String>`
    pub fn add_stop(&mut self, stop: impl Into<String>) -> &mut Self {
        // derive_builder stores `Option<T>` fields as Option<Option<T>> in the
        // builder, hence the nested initializer and the infallible `expect`.
        self.stop
            .get_or_insert_with(|| Some(vec![]))
            .as_mut()
            .expect("stop should always be Some(Vec)")
            .push(stop.into());
        self
    }
    /// Add a token and bias to the `HashMap<String, i32>` in the ChatCompletionRequest
    /// This will either create or update the `HashMap<String, i32>`
    /// See: [`ChatCompletionRequest::logit_bias`] for more details
    pub fn add_logit_bias<T>(&mut self, token_id: T, bias: i32) -> &mut Self
    where
        T: std::fmt::Display,
    {
        // Key may be given as int or string; it is stringified here and validated
        // as an integer later by `validate_logit_bias`.
        self.logit_bias
            .get_or_insert_with(|| Some(HashMap::new()))
            .as_mut()
            .expect("logit_bias should always be Some(HashMap)")
            .insert(token_id.to_string(), bias);
        self
    }
}
/// Each turn in a conversation is represented by a ChatCompletionMessage.
#[derive(Builder, Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
    /// Author of this turn (user / system / assistant / function).
    pub role: MessageRole,
    /// Message payload; deserialized from either a bare string or an array of parts.
    #[serde(deserialize_with = "deserialize_content")]
    pub content: Content,
    /// Optional author name; omitted from the wire when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    #[builder(default)]
    pub name: Option<String>,
}
/// Message author; lowercase variant names match the OpenAI wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum MessageRole {
    user,
    system,
    assistant,
    function,
}
impl Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
use MessageRole::*;
let s = match self {
user => "user",
system => "system",
assistant => "assistant",
function => "function",
};
write!(f, "{s}")
}
}
/// Message content: either plain text or a list of multimodal parts.
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
pub enum Content {
    /// Text-only content, serialized as a bare JSON string.
    Text(String),
    /// Mixed text/image parts, serialized as a JSON array.
    ImageUrl(Vec<ImageUrl>),
}
impl serde::Serialize for Content {
    /// Serializes `Text` as a bare JSON string and `ImageUrl` as an array of
    /// parts, mirroring the two shapes OpenAI accepts for message content.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Content::Text(text) => serializer.serialize_str(text),
            Content::ImageUrl(parts) => parts.serialize(serializer),
        }
    }
}
/// Deserializes `Content` from either a bare string or a JSON sequence.
///
/// NOTE(review): sequence elements are deserialized as plain strings and classified
/// as image parts purely by an http(s) prefix — this differs from OpenAI's
/// array-of-objects content-part format; confirm this matches what callers send.
fn deserialize_content<'de, D>(deserializer: D) -> Result<Content, D::Error>
where
    D: Deserializer<'de>,
{
    struct ContentVisitor;
    impl<'de> Visitor<'de> for ContentVisitor {
        type Value = Content;
        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a string or an array of content parts")
        }
        // Bare JSON string -> text content.
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Content::Text(value.to_owned()))
        }
        // JSON array -> list of parts; URL-looking strings become image parts,
        // everything else becomes a text part.
        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let mut parts = Vec::new();
            while let Some(value) = seq.next_element::<String>()? {
                if value.starts_with("http://") || value.starts_with("https://") {
                    parts.push(ImageUrl {
                        r#type: ContentType::image_url,
                        text: None,
                        image_url: Some(ImageUrlType { url: value }),
                    });
                } else {
                    parts.push(ImageUrl {
                        r#type: ContentType::text,
                        text: Some(value),
                        image_url: None,
                    });
                }
            }
            Ok(Content::ImageUrl(parts))
        }
    }
    // deserialize_any lets serde dispatch to visit_str vs visit_seq by input shape.
    deserializer.deserialize_any(ContentVisitor)
}
/// Discriminator for a content part; lowercase names match the OpenAI wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
    text,
    image_url,
}
/// Wrapper object holding an image URL, as required by the OpenAI schema.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub struct ImageUrlType {
    pub url: String,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub struct ImageUrl {
pub r#type: ContentType,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrlType>,
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionResponse {
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionResponse,
}
/// Represents a chat completion response returned by model, based on the provided input.
pub type ChatCompletionResponse = ChatCompletionGeneric<ChatCompletionChoice>;
/// Represents a streamed chunk of a chat completion response returned by model, based on the provided input.
pub type ChatCompletionResponseDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;
/// Common structure for chat completion responses; the only delta is the type of choices which differs
/// between streaming and non-streaming requests.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionGeneric<C>
where
    C: Serialize + Clone + ContentProvider,
{
    /// A unique identifier for the chat completion.
    pub id: String,
    /// A list of chat completion choices. Can be more than one if n is greater than 1.
    pub choices: Vec<C>,
    /// The Unix timestamp (in seconds) of when the chat completion was created.
    pub created: u64,
    /// The model used for the chat completion.
    pub model: String,
    /// The object type, which is `chat.completion` if the type of `Choice` is `ChatCompletionChoice`,
    /// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
    pub object: String,
    /// Usage information for the completion request.
    pub usage: Option<CompletionUsage>,
    /// The service tier used for processing the request, optional.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<ServiceTier>,
    /// This fingerprint represents the backend configuration that the model runs with.
    ///
    /// Can be used in conjunction with the seed request parameter to understand when backend changes
    /// have been made that might impact determinism.
    ///
    /// NIM Compatibility:
    /// This field is not supported by the NIM; however it will be added in the future.
    /// The optional nature of this field will be relaxed when it is supported.
    pub system_fingerprint: Option<String>,
    // TODO() - add NvResponseExtention
}
// Service tier used to process the request: "auto", "scale", or "default"
// (serialized in snake_case).
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Scale,
    Default,
}
/// One non-streaming completion choice produced by the model.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ChatCompletionChoice {
    /// A chat completion message generated by the model.
    pub message: ChatCompletionContent,
    /// The index of the choice in the list of choices.
    pub index: u64,
    /// The reason the model stopped generating tokens. This will be `stop` if the model hit a natural
    /// stop point or a provided stop sequence, `length` if the maximum number of tokens specified
    /// in the request was reached, `content_filter` if content was omitted due to a flag from our content
    /// filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called
    /// a function.
    ///
    /// NIM Compatibility:
    /// Only `stop` and `length` are currently supported by NIM.
    /// NIM may also provide additional reasons in the future, such as `error`, `timeout` or `cancelation`.
    pub finish_reason: FinishReason,
    /// Log probability information for the choice, optional field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatCompletionLogprobs>,
}
impl ContentProvider for ChatCompletionChoice {
    /// Delegates to the inner message's text content.
    fn content(&self) -> String {
        self.message.content()
    }
}
/// Same as ChatCompletionMessage, but received during a response stream.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChatCompletionChoiceDelta {
    /// The index of the choice in the list of choices.
    pub index: u64,
    /// The reason the model stopped generating tokens. This will be `stop` if the model hit a natural
    /// stop point or a provided stop sequence, `length` if the maximum number of tokens specified
    /// in the request was reached, `content_filter` if content was omitted due to a flag from our content
    /// filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called
    /// a function.
    ///
    /// NIM Compatibility:
    /// Only `stop` and `length` are currently supported by NIM.
    /// NIM may also provide additional reasons in the future, such as `error`, `timeout` or `cancelation`.
    /// `None` until the stream's final chunk for this choice.
    pub finish_reason: Option<FinishReason>,
    /// A chat completion delta generated by streamed model responses.
    pub delta: ChatCompletionContent,
    /// Log probability information for the choice, optional field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatCompletionLogprobs>,
}
impl ContentProvider for ChatCompletionChoiceDelta {
    /// Delegates to the delta's text content.
    fn content(&self) -> String {
        self.delta.content()
    }
}
/// A chat completion message generated by the model.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionContent {
/// The role of the author of this message.
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<MessageRole>,
/// The contents of the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Tool calls made by the model.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
impl ContentProvider for ChatCompletionContent {
    /// Returns the message text, or an empty string when the model produced no
    /// content (e.g. a pure tool-call delta).
    fn content(&self) -> String {
        self.content.clone().unwrap_or("".to_string())
    }
}
// NOTE(review): a stray `#[serde(flatten)] pub inner:
// async_openai::types::ChatCompletionStreamResponseDelta` field declaration was
// interleaved inside this impl block by a bad merge; field declarations are not legal
// inside an impl and it has been removed.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum ToolChoiceType {
None,
Auto,
ToolChoice { tool: Tool },
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionResponseDelta {
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionStreamResponse,
}
/// A function the model may call, described by name and a JSON-schema parameter spec.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Function {
    pub name: String,
    /// Optional human-readable description; omitted from the wire when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: FunctionParameters,
}
/// JSON Schema primitive types; serialized lowercase ("object", "number", ...).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JSONSchemaType {
    Object,
    Number,
    String,
    Array,
    Null,
    Boolean,
}
/// A (recursive) JSON Schema node used to describe function parameters.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
pub struct JSONSchemaDefine {
    #[serde(rename = "type")]
    pub schema_type: Option<JSONSchemaType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Allowed values for enum-typed properties.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enum_values: Option<Vec<String>>,
    /// Child properties for object-typed nodes; boxed to allow recursion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
    /// Element schema for array-typed nodes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<JSONSchemaDefine>>,
}
/// Top-level parameter schema for a [`Function`] (the root is typically an object).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct FunctionParameters {
    #[serde(rename = "type")]
    pub schema_type: JSONSchemaType,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required: Option<Vec<String>>,
}
/// Why generation stopped; lowercase variant names match the OpenAI wire format.
/// `cancelled` and `null` are local extensions beyond the OpenAI set.
#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum FinishReason {
    stop,
    length,
    content_filter,
    tool_calls,
    cancelled,
    null,
}
/// Parses a wire-format finish reason string back into a [`FinishReason`].
///
/// # Errors
/// Returns a descriptive `String` error for unrecognized input.
impl std::str::FromStr for FinishReason {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "stop" => Ok(FinishReason::stop),
            "length" => Ok(FinishReason::length),
            "content_filter" => Ok(FinishReason::content_filter),
            "tool_calls" => Ok(FinishReason::tool_calls),
            // Previously missing: `Display` emits "cancelled", so parsing must
            // accept it for Display/FromStr to round-trip.
            "cancelled" => Ok(FinishReason::cancelled),
            "null" => Ok(FinishReason::null),
            _ => Err(format!("Unknown FinishReason: {}", s)),
        }
    }
}
impl std::fmt::Display for FinishReason {
    /// Writes the lowercase wire name, matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let label = match self {
            FinishReason::stop => "stop",
            FinishReason::length => "length",
            FinishReason::content_filter => "content_filter",
            FinishReason::tool_calls => "tool_calls",
            FinishReason::cancelled => "cancelled",
            FinishReason::null => "null",
        };
        f.write_str(label)
    }
}
/// Expanded stop information: the reason plus the matched stop string.
#[derive(Debug, Deserialize, Serialize)]
#[allow(non_camel_case_types)]
pub struct FinishDetails {
    pub r#type: FinishReason,
    pub stop: String,
}
/// A single tool invocation emitted by the model.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: ToolCallFunction,
}
/// Function name/arguments of a tool call; both optional because streaming
/// deltas may deliver them incrementally.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCallFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
/// Serializes `tool_choice` the way OpenAI expects: the strings "none"/"auto",
/// or a `{"type": ..., "function": ...}` object for a specific tool.
fn serialize_tool_choice<S>(
    value: &Option<ToolChoiceType>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let Some(choice) = value else {
        return serializer.serialize_none();
    };
    match choice {
        ToolChoiceType::None => serializer.serialize_str("none"),
        ToolChoiceType::Auto => serializer.serialize_str("auto"),
        ToolChoiceType::ToolChoice { tool } => {
            let mut map = serializer.serialize_map(Some(2))?;
            map.serialize_entry("type", &tool.r#type)?;
            map.serialize_entry("function", &tool.function)?;
            map.end()
        }
    }
}
/// A callable tool exposed to the model; currently only functions.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: Function,
}
/// Kind of tool; serialized in snake_case ("function").
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    Function,
}
// NOTE(review): empty inherent impl — dead code; safe to delete once confirmed
// nothing (e.g. a macro) expands into it.
impl ChatCompletionRequest {}
impl NvExtProvider for ChatCompletionRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
......@@ -810,19 +81,19 @@ impl AnnotationsProvider for ChatCompletionRequest {
// NOTE(review): this impl contains both pre-refactor (`self.temperature`) and
// post-refactor (`self.inner.temperature`) statements back-to-back — an artifact of
// rendering the async_openai refactor diff without +/- markers. As written it cannot
// compile; exactly one line per getter should be kept (the `self.inner.*` form if the
// request struct now embeds `async_openai::types::CreateChatCompletionRequest` —
// confirm against the reconciled struct definition). The `nvext` body below is also
// elided by a diff marker.
impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
    fn get_temperature(&self) -> Option<f32> {
        self.temperature
        self.inner.temperature
    }
    fn get_top_p(&self) -> Option<f32> {
        self.top_p
        self.inner.top_p
    }
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.frequency_penalty
        self.inner.frequency_penalty
    }
    fn get_presence_penalty(&self) -> Option<f32> {
        self.presence_penalty
        self.inner.presence_penalty
    }
    fn nvext(&self) -> Option<&NvExt> {
......@@ -830,815 +101,26 @@ impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
    }
}
// NOTE(review): pre- and post-refactor method bodies are interleaved below (diff
// rendering artifact); this block does not compile as-is. The post-refactor versions
// appear to be the `Option<u32>` / `self.inner.*` forms; the `Option<i32>` lines look
// like the pre-refactor code and should be dropped when reconciling — confirm against
// the `OpenAIStopConditionsProvider` trait definition.
#[allow(deprecated)]
impl OpenAIStopConditionsProvider for ChatCompletionRequest {
    fn get_max_tokens(&self) -> Option<i32> {
        self.max_tokens
    fn get_max_tokens(&self) -> Option<u32> {
        // ALLOW: max_tokens is deprecated in favor of max_completion_tokens
        self.inner.max_tokens
    }
    fn get_min_tokens(&self) -> Option<i32> {
        self.min_tokens
    fn get_min_tokens(&self) -> Option<u32> {
        // TODO THIS IS WRONG min_tokens does not exist
        None
    }
    fn get_stop(&self) -> Option<Vec<String>> {
        self.stop.clone()
        // TODO THIS IS WRONG should instead do
        // Vec<String> -> async_openai::types::Stop
        // self.inner.stop.clone()
        None
    }
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}
/// Implements TryFrom for converting an OpenAI's ChatCompletionRequest to an Engine's CompletionRequest
impl TryFrom<ChatCompletionRequest> for common::CompletionRequest {
    type Error = anyhow::Error;
    fn try_from(request: ChatCompletionRequest) -> Result<Self, Self::Error> {
        // Field-by-field disposition of the OpenAI request (legend below):
        // openai_api_rs::v1::chat_completion
        // pub struct ChatCompletionRequest {
        //    NA pub model: String,
        //    L  pub messages: Vec<ChatCompletionMessage, Global>,
        //    SO pub temperature: Option<f32>,
        //    SO pub top_p: Option<f32>,
        //    SO pub n: Option<i32>,
        //    ** pub response_format: Option<Value>,
        //    NA pub stream: Option<bool>, // See Issue #8
        //    SC pub stop: Option<Vec<String, Global>>,
        //    SC pub max_tokens: Option<i32>,
        //    SO pub presence_penalty: Option<f32>,
        //    SO pub frequency_penalty: Option<f32>,
        //    ** pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //    ** pub user: Option<String>,
        //    SO pub seed: Option<i64>,
        //    ** pub tools: Option<Vec<Tool, Global>>,
        //    ** pub tool_choice: Option<ToolChoiceType>,
        // }
        //
        // ** not supported
        // NA not applicable
        // L  local in this method
        // SO extract_sampling_options
        // SC extract_stop_conditions
        // first we validate the OpenAI request
        // we can not validate everything as some fields require backend awareness
        // however, we can validate against the public OpenAI limit
        request
            .validate()
            .map_err(|e| anyhow::anyhow!("Failed to validate ChatCompletionRequest: {}", e))?;
        // Unsupported fields fail loudly rather than being silently dropped.
        // todo(ryan) - open a ticket to support this
        if request.logit_bias.is_some() {
            anyhow::bail!("logit_bias is not supported");
        }
        // todo(ryan) - add support for user
        if request.user.is_some() {
            anyhow::bail!("user is not supported");
        }
        if request.response_format.is_some() {
            anyhow::bail!("response_format is not supported");
        }
        if request.tools.is_some() {
            anyhow::bail!("tools is not supported");
        }
        if request.tool_choice.is_some() {
            anyhow::bail!("tool_choice is not supported");
        }
        // sampling options
        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract SamplingOptions: {}", e))?;
        // stop conditions
        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract StopConditions: {}", e))?;
        // first we need to process the messages: enforce role ordering and fold the
        // conversation into the engine's prompt representation
        let prompt = common::PromptType::ChatCompletion(
            validate_and_collect_chat_messages(request.messages)
                .map_err(|e| anyhow::anyhow!("Failed to validate chat messages: {}", e))?,
        );
        // return the completion request
        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
            mdc_sum: None,
            annotations: None,
        })
    }
}
impl TryFrom<common::StreamingCompletionResponse> for ChatCompletionChoice {
    type Error = anyhow::Error;
    /// Converts an engine streaming response into a non-streaming OpenAI choice,
    /// surfacing a backend error finish reason as a conversion failure.
    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
        // Map the engine-level finish reason onto the OpenAI wire enum first so an
        // error can abort before any struct is assembled.
        let finish_reason = match &response.delta.finish_reason {
            Some(common::FinishReason::EoS) | Some(common::FinishReason::Stop) => {
                FinishReason::stop
            }
            Some(common::FinishReason::Length) => FinishReason::length,
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
            }
            Some(common::FinishReason::Cancelled) | None => FinishReason::null,
        };
        Ok(ChatCompletionChoice {
            index: response.delta.index.unwrap_or(0) as u64,
            message: ChatCompletionContent {
                role: Some(MessageRole::assistant),
                content: response.delta.text,
                tool_calls: None,
            },
            finish_reason,
            logprobs: response.logprobs,
        })
    }
}
impl TryFrom<common::StreamingCompletionResponse> for ChatCompletionChoiceDelta {
    type Error = anyhow::Error;
    /// Converts an engine streaming response into an OpenAI streaming choice delta,
    /// surfacing a backend error finish reason as a conversion failure.
    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
        // Resolve the optional finish reason up front; `None` stays `None` for
        // mid-stream chunks.
        let finish_reason = match &response.delta.finish_reason {
            Some(common::FinishReason::EoS) | Some(common::FinishReason::Stop) => {
                Some(FinishReason::stop)
            }
            Some(common::FinishReason::Length) => Some(FinishReason::length),
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
            }
            Some(common::FinishReason::Cancelled) => Some(FinishReason::null),
            None => None,
        };
        Ok(ChatCompletionChoiceDelta {
            index: response.delta.index.unwrap_or(0) as u64,
            delta: ChatCompletionContent {
                role: Some(MessageRole::assistant),
                content: response.delta.text,
                tool_calls: None,
            },
            finish_reason,
            logprobs: response.logprobs,
        })
    }
}
/// Validates the conversation shape and folds it into the engine's chat context.
///
/// Enforced invariants:
/// - at most one system message (its text becomes the system prompt)
/// - user/assistant messages strictly alternate, starting AND ending with user
/// - function messages are ignored entirely
/// The trailing user message becomes the prompt; each preceding (user, assistant)
/// pair becomes one history turn.
fn validate_and_collect_chat_messages(
    messages: Vec<ChatCompletionMessage>,
) -> Result<common::ChatContext, anyhow::Error> {
    let mut system_prompt = None;
    let mut turns = VecDeque::new();
    // Seeded with `assistant` so the first accepted message must be `user`.
    let mut last_role = MessageRole::assistant;
    for message in messages {
        match message.role {
            MessageRole::system => {
                if system_prompt.is_some() {
                    return Err(anyhow::anyhow!("More than one system message found"));
                }
                system_prompt = Some(message.content);
            }
            MessageRole::user | MessageRole::assistant => {
                if last_role == message.role {
                    if turns.is_empty() {
                        return Err(anyhow::anyhow!("First message must be a user message"));
                    }
                    return Err(anyhow::anyhow!(
                        "User and assistant messages must alternate"
                    ));
                }
                last_role = message.role.clone();
                turns.push_back(message);
            }
            MessageRole::function => {} // Ignoring function messages as per assumption.
        }
    }
    // Defensive re-check of the starting role (also covered by the loop above).
    if let Some(first) = turns.front() {
        if let MessageRole::assistant = first.role {
            return Err(anyhow::anyhow!("Sequence must start with a user message"));
        }
    }
    // Even count (including zero) means the sequence does not end on a user turn;
    // rejecting here also guarantees the final pop_back() below cannot panic.
    if turns.len() % 2 == 0 {
        return Err(anyhow::anyhow!("Sequence must end with a user message"));
    }
    // Drain complete (user, assistant) pairs into history turns.
    let mut context = Vec::new();
    while turns.len() >= 2 {
        let user = turns.pop_front().unwrap();
        let asst = turns.pop_front().unwrap();
        let user = match user.content {
            Content::Text(text) => text,
            _ => return Err(anyhow::anyhow!("User message must be text")),
        };
        let asst = match asst.content {
            Content::Text(text) => text,
            _ => return Err(anyhow::anyhow!("Assistant message must be text")),
        };
        context.push(common::ChatTurn {
            user,
            assistant: asst,
        });
    }
    // Exactly one message remains (odd length was enforced above): the prompt.
    let prompt = turns.pop_back().unwrap();
    let prompt = match prompt.content {
        Content::Text(text) => text,
        _ => return Err(anyhow::anyhow!("Prompt message must be text")),
    };
    let system_prompt = match system_prompt {
        Some(Content::Text(text)) => Some(text),
        Some(_) => return Err(anyhow::anyhow!("System prompt must be text")),
        None => None,
    };
    Ok(common::ChatContext {
        completion: common::CompletionContext {
            prompt,
            system_prompt,
        },
        context,
    })
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use serde_json::json;
use std::error::Error;
use super::*;
// Model and one user message are the only required fields.
#[test]
fn test_chat_completions_valid_request_minimal() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello!")
        .build();
    assert!(
        request.is_ok(),
        "Request should succeed with minimal fields"
    );
    Ok(())
}
// Every optional field set to an in-range value must still validate.
#[test]
fn test_chat_completions_valid_request_full() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello!")
        .max_tokens(50)
        .stream(true)
        .n(1)
        .temperature(1.0)
        .top_p(0.9)
        .frequency_penalty(0.5)
        .presence_penalty(0.5)
        .stop(vec!["The end.".to_string()])
        .logprobs(true)
        .top_logprobs(5)
        .logit_bias(HashMap::new())
        .user("test_user")
        .seed(1234)
        .build();
    // NOTE(review): leftover debug print — consider removing.
    println!("{:?}", request);
    assert!(
        request.is_ok(),
        "Request should succeed with all fields set"
    );
    Ok(())
}
// Exercises the mutually-exclusive-field check in `build()`.
#[test]
fn test_chat_completions_top_logprobs_requires_logprobs() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello!")
        .top_logprobs(5) // logprobs is not set to true
        .build();
    assert!(
        request.is_err(),
        "Request should fail when top_logprobs is set without logprobs being true"
    );
    Ok(())
}
// Ignored: requires backend/model awareness (context length) that client-side
// validation does not have — TODO confirm intent.
#[ignore]
#[test]
fn test_chat_completions_max_tokens_out_of_range() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello!")
        .max_tokens(4097) // assuming the model has a max context length of 4096
        .build();
    assert!(
        request.is_err(),
        "Request should fail when max_tokens exceeds model's context length"
    );
    Ok(())
}
// top_p outside [0, 1] must be rejected by field validation.
#[test]
fn test_chat_completions_invalid_top_p() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello!")
        .top_p(1.5) // Invalid, should be between 0 and 1
        .build();
    assert!(
        request.is_err(),
        "Request should fail with invalid top_p value"
    );
    Ok(())
}
// The builder must refuse to build without any messages.
#[test]
fn test_chat_completions_missing_messages() -> Result<(), Box<dyn Error>> {
    // Missing messages field in the request
    let request_result = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct") // Valid model
        .build(); // This should fail because no messages are provided.
    assert!(
        request_result.is_err(),
        "Expected request to fail without messages."
    );
    if let Err(e) = request_result {
        println!("Expected error: {}", e); // Optionally print the error for debugging
    }
    Ok(())
}
// Negative token budgets are invalid.
#[test]
fn test_chat_completions_negative_max_tokens() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello, world!")
        .max_tokens(-10)
        .build();
    assert!(
        request.is_err(),
        "Request should fail with negative max_tokens"
    );
    Ok(())
}
// Ignored: logit_bias is rejected at conversion time (TryFrom), not at build
// time — TODO confirm intent.
#[ignore]
#[test]
fn test_chat_completions_unsupported_logit_bias() -> Result<(), Box<dyn Error>> {
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hello, world!")
        .add_logit_bias("50256", -100)
        .build();
    assert!(request.is_err(), "Request should fail with logit_bias");
    Ok(())
}
#[test]
fn test_chat_completions_invalid_temperature() -> Result<(), Box<dyn Error>> {
let request = ChatCompletionRequest::builder()
.model("meta/llama-3.1-8b-instruct")
.add_user_message("Hello!")
.temperature(2.5) // Invalid, should be between 0 and 2
.build();
assert!(
request.is_err(),
"Request should fail with invalid temperature"
);
Ok(())
}
#[test]
fn test_chat_completions_max_stop_sequences() -> Result<(), Box<dyn Error>> {
let request = ChatCompletionRequest::builder()
.model("meta/llama-3.1-8b-instruct")
.add_user_message("Tell me a story.")
.stop(vec![
"The end.".to_string(),
"Once upon a time,".to_string(),
"And then,".to_string(),
"They lived happily ever after.".to_string(),
]) // 4 stop sequences, valid
.build();
assert!(
request.is_ok(),
"Request should succeed with 4 stop sequences"
);
Ok(())
}
#[test]
fn test_chat_completions_large_stop_sequences() -> Result<(), Box<dyn Error>> {
let request = ChatCompletionRequest::builder()
.model("meta/llama-3.1-8b-instruct")
.add_user_message("Tell me a story.")
.stop(vec![
"The end.".to_string(),
"And so,".to_string(),
"Once upon a time,".to_string(),
"They lived happily ever after.".to_string(),
"Unexpected stop.".to_string(),
])
.build();
assert!(
request.is_err(),
"Request should fail with too many stop sequences"
);
Ok(())
}
#[ignore]
#[test]
fn test_chat_completions_invalid_stop_sequences() -> Result<(), Box<dyn Error>> {
    // An empty string is not a usable stop sequence; building must fail.
    let failed = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Tell me a joke.")
        .stop(vec!["".to_string()])
        .build()
        .is_err();
    assert!(
        failed,
        "Request should fail with invalid stop sequences"
    );
    Ok(())
}
#[ignore]
#[test]
fn test_chat_completions_presence_penalty_out_of_range() -> Result<(), Box<dyn Error>> {
    // presence_penalty must lie in [-2.0, 2.0]; 3.0 is above the range.
    let failed = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("What's up?")
        .presence_penalty(3.0)
        .build()
        .is_err();
    assert!(
        failed,
        "Request should fail with invalid presence_penalty"
    );
    Ok(())
}
#[test]
fn test_chat_completions_invalid_presence_penalty() -> Result<(), Box<dyn Error>> {
    // presence_penalty must lie in [-2.0, 2.0]; -2.5 is below the range.
    let failed = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("What's up?")
        .presence_penalty(-2.5)
        .build()
        .is_err();
    assert!(
        failed,
        "Request should fail with invalid presence_penalty"
    );
    Ok(())
}
#[ignore]
#[tokio::test]
async fn test_chat_completions_with_user_field() -> Result<(), Box<dyn Error>> {
    // The builder accepts the OpenAI `user` field, but the downstream
    // common::CompletionRequest does not support it, so the conversion
    // (not the build) is what must fail.
    let request = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Hi there!")
        .user("test_user")
        .build()
        .unwrap();
    let result: Result<common::CompletionRequest> = request.try_into();
    assert!(
        result.is_err(),
        "Conversion should fail with 'user' field set",
    );
    Ok(())
}
#[test]
fn test_chat_completions_valid_with_seed() -> Result<(), Box<dyn Error>> {
    // A seed is a supported knob for deterministic sampling; building must succeed.
    let build_result = ChatCompletionRequest::builder()
        .model("meta/llama-3.1-8b-instruct")
        .add_user_message("Repeatable result")
        .seed(12345)
        .build();
    assert!(
        build_result.is_ok(),
        "Request should succeed with seed value for determinism"
    );
    Ok(())
}
#[test]
fn test_validate_chat_messages_multiple_system_messages() -> Result<(), Box<dyn Error>> {
    // A conversation may carry at most one system message; two must be rejected.
    let request = ChatCompletionRequest::builder()
        .model("test-model")
        .add_system_message("System message 1")
        .add_system_message("System message 2")
        .add_user_message("Hello!")
        .build()?;
    // `request` is not used afterwards, so hand the messages over without cloning.
    let result = validate_and_collect_chat_messages(request.messages);
    assert!(result.is_err());
    if let Err(e) = result {
        assert_eq!(e.to_string(), "More than one system message found");
    }
    Ok(())
}
#[test]
fn test_validate_chat_messages_user_messages_do_not_alternate() -> Result<(), Box<dyn Error>> {
    // Two consecutive user messages break the required user/assistant alternation.
    let request = ChatCompletionRequest::builder()
        .model("test-model")
        .add_user_message("Hello!")
        .add_user_message("How are you?")
        .build()?;
    // `request` is not used afterwards, so hand the messages over without cloning.
    let result = validate_and_collect_chat_messages(request.messages);
    assert!(result.is_err());
    if let Err(e) = result {
        assert_eq!(e.to_string(), "User and assistant messages must alternate");
    }
    Ok(())
}
#[ignore]
#[test]
fn test_validate_chat_messages_user_message_not_text() -> Result<(), Box<dyn Error>> {
    // A user message whose content is image parts (not plain text) must be
    // rejected by message validation.
    let message = ChatCompletionMessage {
        role: MessageRole::user,
        content: Content::ImageUrl(vec![ImageUrl {
            r#type: ContentType::image_url,
            text: None,
            image_url: Some(ImageUrlType {
                url: "http://example.com/image.png".to_string(),
            }),
        }]),
        name: None,
    };
    let request = ChatCompletionRequest::builder()
        .model("test-model")
        .add_message(message)
        .build()?;
    // `request` is not used afterwards, so hand the messages over without cloning.
    let result = validate_and_collect_chat_messages(request.messages);
    assert!(result.is_err());
    if let Err(e) = result {
        assert_eq!(e.to_string(), "Generic error: User message must be text");
    }
    Ok(())
}
#[test]
fn test_try_from_chat_completion_request_with_unsupported_fields() -> Result<(), Box<dyn Error>>
{
    // response_format / tools / tool_choice are accepted by the builder but are
    // not representable in common::CompletionRequest, so conversion must fail.
    let tool = Tool {
        r#type: ToolType::Function,
        function: Function {
            name: "test_function".to_string(),
            description: None,
            parameters: FunctionParameters {
                schema_type: JSONSchemaType::Object,
                properties: None,
                required: None,
            },
        },
    };
    let request = ChatCompletionRequest::builder()
        .model("test-model")
        .add_user_message("Hello!")
        .response_format(Some(json!({"format": "unsupported"})))
        .tools(Some(vec![tool]))
        .tool_choice(Some(ToolChoiceType::Auto))
        .build()?;
    let result: Result<common::CompletionRequest> = request.try_into();
    assert!(
        result.is_err(),
        "Conversion should fail with unsupported fields"
    );
    Ok(())
}
#[test]
fn test_deserialize_content_with_image_urls() {
    // A plain string array in `content` should deserialize into typed parts:
    // URL-looking strings become image_url parts, the rest become text parts.
    let json_data = r#"
    {
        "role": "assistant",
        "content": [
            "This is a text message.",
            "https://example.com/image1.png",
            "Another text message.",
            "https://example.com/image2.png"
        ]
    }
    "#;
    let message: ChatCompletionMessage =
        serde_json::from_str(json_data).expect("Deserialization failed");
    if let Content::ImageUrl(parts) = message.content {
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0].r#type, ContentType::text);
        assert_eq!(parts[0].text.as_ref().unwrap(), "This is a text message.");
        assert_eq!(parts[1].r#type, ContentType::image_url);
        assert_eq!(
            parts[1].image_url.as_ref().unwrap().url,
            "https://example.com/image1.png"
        );
        // Previously unchecked: the last two parts must follow the same pattern.
        assert_eq!(parts[2].r#type, ContentType::text);
        assert_eq!(parts[2].text.as_ref().unwrap(), "Another text message.");
        assert_eq!(parts[3].r#type, ContentType::image_url);
        assert_eq!(
            parts[3].image_url.as_ref().unwrap().url,
            "https://example.com/image2.png"
        );
    } else {
        panic!("Expected Content::ImageUrl");
    }
}
#[test]
fn test_try_from_chat_completion_request_success() -> Result<(), Box<dyn Error>> {
    // A well-formed alternating conversation converts into a chat-completion prompt.
    let chat_request = ChatCompletionRequest::builder()
        .model("test-model")
        .add_user_message("Hello!")
        .add_assistant_message("Hi there!")
        .add_user_message("How are you?")
        .build()?;
    let converted: common::CompletionRequest = chat_request.try_into()?;
    let is_chat_prompt = matches!(converted.prompt, common::PromptType::ChatCompletion(_));
    assert!(is_chat_prompt);
    Ok(())
}
#[test]
fn test_chat_completion_sampling_params_with_valid_nvext() {
    // NvExt sampling extensions should round-trip through the request builder.
    // Use struct-update syntax for the untouched fields, matching the sibling
    // test_completion_sampling_params_with_valid_nvext below.
    let nvext = NvExt {
        ignore_eos: Some(true),
        repetition_penalty: Some(0.6),
        top_k: Some(3),
        ..Default::default()
    };
    let request = ChatCompletionRequest::builder()
        .nvext(nvext)
        .model("foo")
        .add_system_message("Hello!")
        .build()
        .expect("Failed to build request with valid nvext");
    let stored = request.nvext.as_ref().unwrap();
    assert_eq!(stored.ignore_eos, Some(true));
    assert_eq!(stored.repetition_penalty, Some(0.6));
    assert_eq!(stored.top_k, Some(3));
}
#[test]
fn test_completion_sampling_params_without_nvext() {
    // Without nvext, optional sampling parameters stay unset.
    let built = ChatCompletionRequest::builder()
        .model("foo")
        .add_user_message("Test")
        .build()
        .unwrap();
    assert!(built.frequency_penalty.is_none());
    assert!(built.logprobs.is_none());
}
#[test]
fn test_completion_sampling_params_with_valid_nvext() {
    // NvExt sampling extensions should be stored on the built request unchanged.
    let nvext = NvExt {
        ignore_eos: Some(true),
        repetition_penalty: Some(0.6),
        top_k: Some(3),
        ..Default::default()
    };
    let request = ChatCompletionRequest::builder()
        .nvext(nvext)
        .model("foo")
        .add_user_message("Test")
        .build()
        .expect("Failed to build request with valid nvext");
    let stored = request.nvext.as_ref().unwrap();
    assert_eq!(stored.ignore_eos, Some(true));
    assert_eq!(stored.repetition_penalty, Some(0.6));
    assert_eq!(stored.top_k, Some(3));
}
// #[test]
// fn test_normalize_unicode_characters() {
// let str = "Hello there how are you\u{E0020}?".to_string();
// let normalized = str.sanitize_text();
// assert_eq!(normalized, "Hello there how are you?");
// }
// #[tokio::test]
// async fn test_chat_completion_request_filtered() {
// // Define input messages with Unicode character to filter
// let messages = vec![
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text(
// "Hello there how are you\u{E0020}?"
// .to_string()
// .normalize_unicode_characters(),
// ),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::assistant,
// content: Content::Text("How may I help you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Do something for me?".to_string()),
// name: None,
// },
// ];
// // Define expected filtered messages
// let expected = vec![
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Hello there how are you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::assistant,
// content: Content::Text("How may I help you?".to_string()),
// name: None,
// },
// ChatCompletionMessage {
// role: MessageRole::user,
// content: Content::Text("Do something for me?".to_string()),
// name: None,
// },
// ];
// // Build ChatCompletionRequest with filtering applied
// let request = ChatCompletionRequest::builder()
// .model("foo")
// .messages(messages)
// .build()
// .expect("Failed to build ChatCompletionRequest");
// // Validate each message matches the expected filtered content
// for (i, message) in request.messages.iter().enumerate() {
// assert_eq!(message.role, expected[i].role);
// if let Content::Text(ref content) = message.content {
// if let Content::Text(ref expected_content) = expected[i].content {
// assert_eq!(content, expected_content);
// }
// }
// }
// }
}
......@@ -13,13 +13,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{
ChatCompletionChoice, ChatCompletionContent, ChatCompletionResponse,
ChatCompletionResponseDelta, CompletionUsage, FinishReason, MessageRole, ServiceTier,
};
use super::{ChatCompletionResponse, ChatCompletionResponseDelta};
use crate::protocols::{
codec::{Message, SseCodecError},
common::ChatCompletionLogprobs,
convert_sse_stream, Annotated,
};
......@@ -32,21 +28,21 @@ type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
pub struct DeltaAggregator {
id: String,
model: String,
created: u64,
usage: Option<CompletionUsage>,
created: u32,
usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>,
choices: HashMap<u64, DeltaChoice>,
choices: HashMap<u32, DeltaChoice>,
error: Option<String>,
service_tier: Option<ServiceTier>,
service_tier: Option<async_openai::types::ServiceTierResponse>,
}
// Holds the accumulated state of a choice
struct DeltaChoice {
index: u64,
index: u32,
text: String,
role: Option<MessageRole>,
finish_reason: Option<FinishReason>,
logprobs: Option<ChatCompletionLogprobs>,
role: Option<async_openai::types::Role>,
finish_reason: Option<async_openai::types::FinishReason>,
logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
}
impl Default for DeltaAggregator {
......@@ -92,19 +88,19 @@ impl DeltaAggregator {
// TODO(#14) - Aggregate Annotation
let delta = delta.data.unwrap();
aggregator.id = delta.id;
aggregator.model = delta.model;
aggregator.created = delta.created;
aggregator.service_tier = delta.service_tier;
if let Some(usage) = delta.usage {
aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created;
aggregator.service_tier = delta.inner.service_tier;
if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage);
}
if let Some(system_fingerprint) = delta.system_fingerprint {
if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint);
}
// handle the choices
for choice in delta.choices {
for choice in delta.inner.choices {
let state_choice =
aggregator
.choices
......@@ -141,12 +137,12 @@ impl DeltaAggregator {
let mut choices: Vec<_> = aggregator
.choices
.into_values()
.map(ChatCompletionChoice::from)
.map(async_openai::types::ChatChoice::from)
.collect();
choices.sort_by(|a, b| a.index.cmp(&b.index));
Ok(ChatCompletionResponse {
let inner = async_openai::types::CreateChatCompletionResponse {
id: aggregator.id,
created: aggregator.created,
usage: aggregator.usage,
......@@ -155,21 +151,30 @@ impl DeltaAggregator {
system_fingerprint: aggregator.system_fingerprint,
choices,
service_tier: aggregator.service_tier,
})
};
let response = ChatCompletionResponse { inner };
Ok(response)
}
}
// todo - handle tool calls
impl From<DeltaChoice> for ChatCompletionChoice {
#[allow(deprecated)]
impl From<DeltaChoice> for async_openai::types::ChatChoice {
fn from(delta: DeltaChoice) -> Self {
ChatCompletionChoice {
message: ChatCompletionContent {
role: delta.role,
// ALLOW: function_call is deprecated
async_openai::types::ChatChoice {
message: async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"),
content: Some(delta.text),
tool_calls: None,
refusal: None,
function_call: None,
audio: None,
},
index: delta.index,
finish_reason: delta.finish_reason.unwrap_or(FinishReason::length),
finish_reason: delta.finish_reason,
logprobs: delta.logprobs,
}
}
......@@ -192,37 +197,47 @@ impl ChatCompletionResponse {
#[cfg(test)]
mod tests {
use crate::protocols::openai::chat_completions::ChatCompletionChoiceDelta;
use super::*;
use futures::stream;
#[allow(deprecated)]
fn create_test_delta(
index: u64,
index: u32,
text: &str,
role: Option<MessageRole>,
finish_reason: Option<FinishReason>,
role: Option<async_openai::types::Role>,
finish_reason: Option<async_openai::types::FinishReason>,
) -> Annotated<ChatCompletionResponseDelta> {
// ALLOW: function_call is deprecated
let delta = async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()),
function_call: None,
tool_calls: None,
role,
refusal: None,
};
let choice = async_openai::types::ChatChoiceStream {
index,
delta,
finish_reason,
logprobs: None,
};
let inner = async_openai::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b-instruct".to_string(),
created: 1234567890,
service_tier: None,
usage: None,
system_fingerprint: None,
choices: vec![choice],
object: "chat.completion".to_string(),
};
let data = ChatCompletionResponseDelta { inner };
Annotated {
data: Some(ChatCompletionResponseDelta {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b-instruct".to_string(),
created: 1234567890,
service_tier: None,
usage: None,
system_fingerprint: None,
choices: vec![ChatCompletionChoiceDelta {
index,
delta: ChatCompletionContent {
role,
content: Some(text.to_string()),
tool_calls: None,
},
finish_reason,
logprobs: None,
}],
object: "chat.completion".to_string(),
}),
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
......@@ -242,19 +257,20 @@ mod tests {
let response = result.unwrap();
// Verify that the response is empty and has default values
assert_eq!(response.id, "");
assert_eq!(response.model, "");
assert_eq!(response.created, 0);
assert!(response.usage.is_none());
assert!(response.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 0);
assert!(response.service_tier.is_none());
assert_eq!(response.inner.id, "");
assert_eq!(response.inner.model, "");
assert_eq!(response.inner.created, 0);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 0);
assert!(response.inner.service_tier.is_none());
}
#[tokio::test]
async fn test_single_delta() {
// Create a sample delta
let annotated_delta = create_test_delta(0, "Hello,", Some(MessageRole::user), None);
let annotated_delta =
create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
// Create a stream
let stream = Box::pin(stream::iter(vec![annotated_delta]));
......@@ -267,18 +283,18 @@ mod tests {
let response = result.unwrap();
// Verify the response fields
assert_eq!(response.id, "test_id");
assert_eq!(response.model, "meta/llama-3.1-8b-instruct");
assert_eq!(response.created, 1234567890);
assert!(response.usage.is_none());
assert!(response.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(response.inner.id, "test_id");
assert_eq!(response.inner.model, "meta/llama-3.1-8b-instruct");
assert_eq!(response.inner.created, 1234567890);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
assert_eq!(choice.finish_reason, FinishReason::length);
assert_eq!(choice.message.role.as_ref().unwrap(), &MessageRole::user);
assert!(response.service_tier.is_none());
assert!(choice.finish_reason.is_none());
assert_eq!(choice.message.role, async_openai::types::Role::User);
assert!(response.inner.service_tier.is_none());
}
#[tokio::test]
......@@ -286,8 +302,14 @@ mod tests {
// Create multiple deltas with the same choice index
// One will have a MessageRole and no FinishReason,
// the other will have a FinishReason and no MessageRole
let annotated_delta1 = create_test_delta(0, "Hello,", Some(MessageRole::user), None);
let annotated_delta2 = create_test_delta(0, " world!", None, Some(FinishReason::stop));
let annotated_delta1 =
create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
let annotated_delta2 = create_test_delta(
0,
" world!",
None,
Some(async_openai::types::FinishReason::Stop),
);
// Create a stream
let annotated_deltas = vec![annotated_delta1, annotated_delta2];
......@@ -301,52 +323,63 @@ mod tests {
let response = result.unwrap();
// Verify the response fields
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
assert_eq!(choice.finish_reason, FinishReason::stop);
assert_eq!(choice.message.role.as_ref().unwrap(), &MessageRole::user);
assert_eq!(
choice.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice.message.role, async_openai::types::Role::User);
}
#[allow(deprecated)]
#[tokio::test]
async fn test_multiple_choices() {
// Create a delta with multiple choices
let delta = ChatCompletionResponseDelta {
// ALLOW: function_call is deprecated
let delta = async_openai::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
model: "test_model".to_string(),
created: 1234567890,
service_tier: None,
usage: None,
system_fingerprint: None,
choices: vec![
ChatCompletionChoiceDelta {
async_openai::types::ChatChoiceStream {
index: 0,
delta: ChatCompletionContent {
role: Some(MessageRole::assistant),
delta: async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(async_openai::types::Role::Assistant),
content: Some("Choice 0".to_string()),
function_call: None,
tool_calls: None,
refusal: None,
},
finish_reason: Some(FinishReason::stop),
finish_reason: Some(async_openai::types::FinishReason::Stop),
logprobs: None,
},
ChatCompletionChoiceDelta {
async_openai::types::ChatChoiceStream {
index: 1,
delta: ChatCompletionContent {
role: Some(MessageRole::assistant),
delta: async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(async_openai::types::Role::Assistant),
content: Some("Choice 1".to_string()),
function_call: None,
tool_calls: None,
refusal: None,
},
finish_reason: Some(FinishReason::stop),
finish_reason: Some(async_openai::types::FinishReason::Stop),
logprobs: None,
},
],
object: "chat.completion".to_string(),
service_tier: None,
};
let data = ChatCompletionResponseDelta { inner: delta };
// Wrap it in Annotated and create a stream
let annotated_delta = Annotated {
data: Some(delta),
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
......@@ -361,24 +394,24 @@ mod tests {
let mut response = result.unwrap();
// Verify the response fields
assert_eq!(response.choices.len(), 2);
response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.choices[0];
assert_eq!(response.inner.choices.len(), 2);
response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.inner.choices[0];
assert_eq!(choice0.index, 0);
assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
assert_eq!(choice0.finish_reason, FinishReason::stop);
assert_eq!(
choice0.message.role.as_ref().unwrap(),
&MessageRole::assistant
choice0.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice0.message.role, async_openai::types::Role::Assistant);
let choice1 = &response.choices[1];
let choice1 = &response.inner.choices[1];
assert_eq!(choice1.index, 1);
assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
assert_eq!(choice1.finish_reason, FinishReason::stop);
assert_eq!(
choice1.message.role.as_ref().unwrap(),
&MessageRole::assistant
choice1.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice1.message.role, async_openai::types::Role::Assistant);
}
}
......@@ -13,12 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{
ChatCompletionChoiceDelta, ChatCompletionContent, ChatCompletionRequest,
ChatCompletionResponseDelta, FinishReason, MessageRole, ServiceTier,
};
use super::{ChatCompletionRequest, ChatCompletionResponseDelta};
use crate::protocols::common;
use crate::protocols::openai::CompletionUsage;
impl ChatCompletionRequest {
// put this method on the request
......@@ -26,10 +22,10 @@ impl ChatCompletionRequest {
pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions {
enable_usage: true,
enable_logprobs: self.logprobs.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false),
};
DeltaGenerator::new(self.model.clone(), options)
DeltaGenerator::new(self.inner.model.clone(), options)
}
}
......@@ -43,11 +39,11 @@ pub struct DeltaGeneratorOptions {
pub struct DeltaGenerator {
id: String,
object: String,
created: u64,
created: u32,
model: String,
system_fingerprint: Option<String>,
service_tier: Option<ServiceTier>,
usage: CompletionUsage,
service_tier: Option<async_openai::types::ServiceTierResponse>,
usage: async_openai::types::CompletionUsage,
// counter on how many messages we have issued
msg_counter: u64,
......@@ -57,10 +53,22 @@ pub struct DeltaGenerator {
impl DeltaGenerator {
pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
// SAFETY: This is a fun one to write. We are casting from u64 to u32
// which typically is unsafe due to loss of precision after it
// exceeds u32::MAX. Fortunately, this won't be an issue until
// 2106. So whoever is still maintaining this then, enjoy!
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
.as_secs() as u32;
let usage = async_openai::types::CompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
};
Self {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
......@@ -69,46 +77,54 @@ impl DeltaGenerator {
model,
system_fingerprint: None,
service_tier: None,
usage: CompletionUsage::default(),
usage,
msg_counter: 0,
options,
}
}
pub fn update_isl(&mut self, isl: i32) {
pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl;
}
#[allow(deprecated)]
pub fn create_choice(
&self,
index: u64,
index: u32,
text: Option<String>,
finish_reason: Option<super::FinishReason>,
logprobs: Option<super::ChatCompletionLogprobs>,
) -> ChatCompletionResponseDelta {
// todo - update for tool calling
let delta = ChatCompletionContent {
content: text,
finish_reason: Option<async_openai::types::FinishReason>,
logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
) -> async_openai::types::CreateChatCompletionStreamResponse {
// TODO: Update for tool calling
// ALLOW: function_call is deprecated
let delta = async_openai::types::ChatCompletionStreamResponseDelta {
role: if self.msg_counter == 0 {
Some(MessageRole::assistant)
Some(async_openai::types::Role::Assistant)
} else {
None
},
content: text,
tool_calls: None,
function_call: None,
refusal: None,
};
ChatCompletionResponseDelta {
let choice = async_openai::types::ChatChoiceStream {
index,
delta,
finish_reason,
logprobs,
};
let choices = vec![choice];
async_openai::types::CreateChatCompletionStreamResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![ChatCompletionChoiceDelta {
index,
delta,
finish_reason,
logprobs,
}],
choices,
usage: if self.options.enable_usage {
Some(self.usage.clone())
} else {
......@@ -126,17 +142,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
) -> anyhow::Result<ChatCompletionResponseDelta> {
// aggregate usage
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as i32;
self.usage.completion_tokens += delta.token_ids.len() as u32;
}
// todo logprobs
let logprobs = None;
let finish_reason = match delta.finish_reason {
Some(common::FinishReason::EoS) => Some(FinishReason::stop),
Some(common::FinishReason::Stop) => Some(FinishReason::stop),
Some(common::FinishReason::Length) => Some(FinishReason::length),
Some(common::FinishReason::Cancelled) => Some(FinishReason::cancelled),
Some(common::FinishReason::EoS) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg));
}
......@@ -145,6 +161,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
// create choice
let index = 0;
Ok(self.create_choice(index, delta.text, finish_reason, logprobs))
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
Ok(ChatCompletionResponseDelta {
inner: stream_response,
})
}
}
......@@ -22,7 +22,7 @@ use validator::Validate;
mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
// pub use aggregator::DeltaAggregator;
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
......@@ -56,13 +56,13 @@ pub struct CompletionRequest {
/// The token count of your prompt plus max_tokens cannot exceed the model's context length.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub max_tokens: Option<i32>,
pub max_tokens: Option<u32>,
/// The minimum number of tokens to generate. We ignore stop tokens until we see this many
/// tokens. Leave this None unless you are working on the pre-processor.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(into, strip_option))]
pub min_tokens: Option<i32>,
pub min_tokens: Option<u32>,
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
/// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
......@@ -248,7 +248,7 @@ impl CompletionRequestBuilder {
/// let request = CompletionRequest::builder()
/// .model("mixtral-8x7b-instruct-v0.1")
/// .prompt("Hello")
/// .max_tokens(16)
/// .max_tokens(16_u32)
/// .build()
/// .expect("Failed to build CompletionRequest");
/// ```
......@@ -433,11 +433,11 @@ impl OpenAISamplingOptionsProvider for CompletionRequest {
}
impl OpenAIStopConditionsProvider for CompletionRequest {
fn get_max_tokens(&self) -> Option<i32> {
fn get_max_tokens(&self) -> Option<u32> {
self.max_tokens
}
fn get_min_tokens(&self) -> Option<i32> {
fn get_min_tokens(&self) -> Option<u32> {
self.min_tokens
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment