diff --git a/Cargo.lock b/Cargo.lock index f6138e14ae641f2c9c4951e3fa2cbb9373ac3d33..2673dd53a2fe475dbb36da0f99d0554d2c4419ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -46,11 +52,17 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + [[package]] name = "anstream" -version = "0.6.14" +version = "0.6.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" dependencies = [ "anstyle", "anstyle-parse", @@ -63,43 +75,43 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "anstyle-parse" -version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.0" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad186efb764318d35165f1758e7dcef3b10628e26d41a44bc5550652e6804391" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.3" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "arbitrary" @@ -121,14 +133,14 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] 
[[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-rustls" @@ -143,9 +155,9 @@ dependencies = [ [[package]] name = "async-stream" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" dependencies = [ "async-stream-impl", "futures-core", @@ -154,24 +166,24 @@ dependencies = [ [[package]] name = "async-stream-impl" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -180,11 +192,22 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "av1-grain" @@ -213,9 +236,9 @@ dependencies = [ [[package]] name = "avif-serialize" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "876c75a42f6364451a033496a14c44bffe41f5f4a8236f697391f11024e596d2" +checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62" dependencies = [ "arrayvec", ] @@ -234,9 +257,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.7.3" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7d844e282b4b56750b2d4e893b2205581ded8709fddd2b6aa5418c150ca877" +checksum = "cdd82dba44d209fddb11c190e0a94b78651f95299598e472215667417a03ff1d" dependencies = [ "aws-lc-sys", "mirai-annotations", @@ -246,9 +269,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.18.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a2c29203f6bf296d01141cc8bb9dbd5ecd4c27843f2ee0767bcd5985a927da" +checksum = "df7a4168111d7eb622a31b214057b8509c0a7e1794f44c546d742330dc793972" dependencies = [ "bindgen", "cc", @@ -272,7 +295,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.31", "itoa", "matchit", "memchr", @@ -286,25 
+309,25 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 0.1.2", "tokio", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", ] [[package]] name = "axum" -version = "0.7.5" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" dependencies = [ "async-trait", - "axum-core 0.4.3", + "axum-core 0.4.5", "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.5.0", "hyper-util", "itoa", "matchit", @@ -319,7 +342,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tower", + "tower 0.5.1", "tower-layer", "tower-service", "tracing", @@ -344,20 +367,20 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" dependencies = [ "async-trait", "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", "rustversion", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.1", "tower-layer", "tower-service", "tracing", @@ -369,13 +392,13 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" dependencies = [ - "axum 0.7.5", + "axum 0.7.7", "futures-core", "futures-util", "http 1.1.0", "opentelemetry 0.21.0", "pin-project-lite", - "tower", + "tower 0.4.13", "tracing", "tracing-opentelemetry 0.22.0", "tracing-opentelemetry-instrumentation-sdk", @@ -383,17 +406,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.8.0", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -416,9 +439,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.4" +version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ "bitflags 2.6.0", "cexpr", @@ -433,7 +456,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.85", "which", ] @@ -472,9 +495,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.4.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415f8399438eb5e4b2f73ed3152a3448b98149dda642a957ee704e1daa5cf1d8" +checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" [[package]] name = "block-buffer" @@ -487,9 +510,9 @@ dependencies = [ [[package]] name = "built" -version = "0.7.3" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" +checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" [[package]] name = "bumpalo" @@ -505,9 +528,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" [[package]] name = "byteorder" @@ -523,15 +546,15 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "camino" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ "serde", ] @@ -565,15 +588,30 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" -version = "1.0.101" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac367972e516d45567c7eafc73d24e1c193dcf200a8d94e9db7b3d38b349572d" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", - "once_cell", + "shlex", ] [[package]] @@ -607,6 +645,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clang-sys" version = "1.8.1" @@ -620,9 +664,20 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.7" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + +[[package]] +name = "clap" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -630,9 +685,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.7" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -642,31 +697,41 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.5" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -675,9 +740,23 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "colorchoice" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + +[[package]] +name = "compact_str" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "static_assertions", +] [[package]] name = "console" @@ -704,15 +783,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" dependencies = [ "libc", ] @@ -726,6 +805,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -762,15 +877,15 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crossterm" -version = "0.27.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ "bitflags 2.6.0", "crossterm_winapi", - "libc", "mio", "parking_lot", + "rustix", "signal-hook", "signal-hook-mio", "winapi", @@ -801,21 +916,86 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctrlc" -version = "3.4.4" +version = "3.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" dependencies = [ - "nix", - "windows-sys 0.52.0", + "nix 0.29.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "cxx" +version = "1.0.129" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbdc8cca144dce1c4981b5c9ab748761619979e515c3d53b5df385c677d1d007" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.129" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5764c3142ab44fcf857101d12c0ddf09c34499900557c764f5ad0597159d1fc" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.85", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.129" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d422aff542b4fa28c2ce8e5cc202d42dbf24702345c1fba3087b2d3f8a1b90ff" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.129" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1719100f31492cd6adeeab9a0f46cdbc846e615fdb66d7b398aa46ec7fdd06f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.85", ] [[package]] name = "darling" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b2eb4d90d12bdda5ed17de686c2acb4c57914f8f921b8da7e112b5a36f3fe1" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -823,27 +1003,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622687fe0bac72a04e5599029151f5796111b90f1baaa9b544d807a5e31cd120" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - 
"syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "darling_macro" -version = "0.20.9" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -857,33 +1037,33 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "derive_builder_macro" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -919,9 +1099,9 @@ dependencies = [ [[package]] name = "dunce" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "easy-cast" @@ -946,9 +1126,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] @@ -986,9 +1166,9 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.4", "rayon-core", "smallvec", "zune-inflate", @@ -1006,15 +1186,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" dependencies = [ "simd-adler32", ] @@ -1027,12 +1207,12 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -1049,9 +1229,9 @@ checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" [[package]] name = "flume" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "spin 0.9.8", ] @@ -1062,6 +1242,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1104,9 +1290,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1119,9 +1305,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1129,15 +1315,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1146,38 +1332,38 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = 
"e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1235,9 +1421,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1267,7 +1453,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -1276,9 +1462,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", "bytes", @@ -1286,13 +1472,19 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1316,6 +1508,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", + "allocator-api2", +] + +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", ] [[package]] @@ -1330,6 +1534,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1412,9 +1625,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1429,15 +1642,15 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.9.4" +version = "1.9.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -1447,9 +1660,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -1471,16 +1684,16 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1492,16 +1705,16 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.2" +version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.5.0", "hyper-util", "log", - "rustls 0.23.10", + "rustls 0.23.15", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1515,7 +1728,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.31", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1528,7 +1741,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.29", + "hyper 0.14.31", "native-tls", "tokio", "tokio-native-tls", @@ -1536,20 +1749,19 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.5" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", - "hyper 1.3.1", + "http-body 1.0.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", - "tower", "tower-service", "tracing", ] @@ -1572,12 +1784,12 @@ dependencies = [ [[package]] name = "image" -version = "0.25.1" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif", @@ -1595,19 +1807,19 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.1.2" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" +checksum = 
"e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" dependencies = [ "byteorder-lite", - "thiserror", + "quick-error", ] [[package]] name = "imgref" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44feda355f4159a7c757171a77de25daf6411e217b4cabd03bd6650690468126" +checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" @@ -1621,12 +1833,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.0", "serde", ] @@ -1662,6 +1874,16 @@ dependencies = [ "tracing-opentelemetry 0.21.0", ] +[[package]] +name = "instability" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c" +dependencies = [ + "quote", + "syn 2.0.85", +] + [[package]] name = "instant" version = "0.1.13" @@ -1679,20 +1901,20 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is_terminal_polyfill" -version = "1.70.0" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "iso8601" @@ -1730,6 +1952,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1738,9 +1969,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1753,9 +1984,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -1770,7 +2001,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.20", "fancy-regex", "fraction", "getrandom", @@ -1810,9 +2041,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.161" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libfuzzer-sys" @@ -1827,12 +2058,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1851,6 +2082,15 @@ dependencies = [ "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -1869,9 +2109,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "loop9" @@ -1882,6 +2122,15 @@ dependencies = [ "imgref", ] +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.0", +] + [[package]] name = "macro_rules_attribute" version = "0.2.0" @@ -1926,7 +2175,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" dependencies = [ "cfg-if", - "rayon", ] [[package]] @@ -1936,14 +2184,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] -name = "metrics" -version = "0.21.1" +name = "memoffset" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" dependencies = [ - "ahash", - "metrics-macros", - "portable-atomic", + "autocfg", ] [[package]] @@ -1958,18 +2204,18 @@ dependencies = [ [[package]] name = "metrics-exporter-prometheus" -version = "0.15.1" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0af7a0d7ced10c0151f870e5e3f3f8bc9ffc5992d32873566ca1f9169ae776" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ "base64 0.22.1", "http-body-util", - "hyper 1.3.1", + "hyper 1.5.0", "hyper-rustls", "hyper-util", - "indexmap 2.2.6", + "indexmap 2.6.0", "ipnet", - "metrics 0.23.0", + "metrics", "metrics-util", "quanta", "thiserror", @@ -1977,17 +2223,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] - [[package]] name = "metrics-util" version = "0.17.0" @@ -1997,7 +2232,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", 
"hashbrown 0.14.5", - "metrics 0.23.0", + "metrics", "num_cpus", "quanta", "sketches-ddsketch", @@ -2011,9 +2246,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mime_guess" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" dependencies = [ "mime", "unicase", @@ -2021,18 +2256,19 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.0.2" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e136ef580d7955019ab0a407b68d77c292a9976907e217900f3f76bc8f6dc1a4" +checksum = "c9ca8daf4b0b4029777f1bc6e1aedd1aec7b74c276a43bc6f620a8e1a1c0a90e" dependencies = [ "serde", + "serde_json", ] [[package]] name = "minijinja-contrib" -version = "2.0.2" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15ee37078c98d31e510d6a7af488031a2c3ccacdb76c5c4fc98ddfe6d0e9da07" +checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" dependencies = [ "minijinja", "serde", @@ -2051,19 +2287,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", +] + +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", "simd-adler32", ] [[package]] name = "mio" -version = "0.8.11" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi 0.3.9", "libc", "log", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2090,7 +2336,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2156,7 +2402,7 @@ dependencies = [ "bytes", "futures", "hostname", - "hyper 0.14.29", + "hyper 0.14.31", "muxado", "once_cell", "parking_lot", @@ -2180,7 +2426,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -2241,9 +2499,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", @@ -2278,7 +2536,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2328,7 +2586,7 @@ version = "1.16.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2349,18 +2607,18 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.0" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "onig" @@ -2384,11 +2642,17 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -2407,7 +2671,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2418,9 +2682,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -2446,7 +2710,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.2.6", + "indexmap 2.6.0", "js-sys", "once_cell", "pin-project-lite", @@ -2454,6 +2718,20 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "opentelemetry" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + [[package]] name = "opentelemetry-otlp" version = "0.13.0" @@ -2547,7 +2825,25 @@ dependencies = [ "glob", "once_cell", "opentelemetry 0.21.0", - "ordered-float 4.2.0", + "ordered-float 4.4.0", + "percent-encoding", + "rand", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df" +dependencies = [ + "async-trait", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry 0.24.0", "percent-encoding", "rand", "thiserror", @@ -2570,9 +2866,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "4.2.0" +version = "4.4.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +checksum = "83e7ccb95e240b7c9506a3d544f10d935e142cc90b0a1d56954fb44d89ad6b97" dependencies = [ "num-traits", ] @@ -2614,7 +2910,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2636,34 +2932,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.6.0", ] [[package]] name = "pin-project" -version = "1.1.5" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.5" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2673,28 +2969,56 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] [[package]] name = "png" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0" dependencies = [ "bitflags 1.3.2", "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = 
"powerfmt" @@ -2704,18 +3028,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2744,30 +3071,30 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "profiling" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d84d1d7a6ac92673717f9f6d1518374ef257669c24ebc5ac25d5033828be58" +checksum = "afbdc74edc00b6f6a218ca6a5364d6226a259d4b8ea1af4a0ea063f27e179f4d" dependencies = [ "profiling-procmacros", ] [[package]] name = "profiling-procmacros" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" +checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2807,7 +3134,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.68", + "syn 2.0.85", "tempfile", ] @@ -2834,7 +3161,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -2846,6 +3173,69 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "pyo3" +version = "0.22.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.85", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.85", +] + [[package]] name = "qoi" version = "0.4.1" @@ -2878,9 +3268,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -2917,18 +3307,22 @@ dependencies = [ [[package]] name = "ratatui" -version = "0.23.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad" +checksum = "fdef7f9be5c0122f890d58bdf4d964349ba6a6161f705907526d891efabba57d" dependencies = [ "bitflags 2.6.0", "cassowary", + "compact_str", "crossterm", - "indoc", - "itertools 0.11.0", + "instability", + "itertools 0.13.0", + "lru", "paste", "strum", + "strum_macros", "unicode-segmentation", + "unicode-truncate", "unicode-width", ] @@ -2969,24 +3363,23 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.7" +version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67376f469e7e7840d0040bbf4b9b3334005bb167f814621326e4c7ab8cd6e944" +checksum = "2413fd96bd0ea5cdeeb37eaf446a22e6ed7b981d792828721e74ded1980a45c6" dependencies = [ "avif-serialize", "imgref", "loop9", "quick-error", "rav1e", - "rayon", "rgb", ] [[package]] name = "raw-cpuid" -version = "11.0.2" +version = "11.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd" +checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ "bitflags 2.6.0", ] @@ -3024,18 +3417,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ "bitflags 2.6.0", ] [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -3044,14 +3437,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", ] [[package]] @@ -3065,13 +3458,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ 
"aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -3082,9 +3475,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" @@ -3100,7 +3493,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.31", "hyper-tls", "ipnet", "js-sys", @@ -3128,12 +3521,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.37" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" -dependencies = [ - "bytemuck", -] +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" [[package]] name = "ring" @@ -3167,9 +3557,9 @@ dependencies = [ [[package]] name = "rust-embed" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19549741604902eb99a7ed0ee177a0663ee1eda51a29f71401f166e47e77806a" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" dependencies = [ "rust-embed-impl", "rust-embed-utils", @@ -3178,22 +3568,22 @@ dependencies = [ [[package]] name = "rust-embed-impl" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb9f96e283ec64401f30d3df8ee2aaeb2561f34c824381efa24a35f79bf40ee4" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.68", + "syn 2.0.85", "walkdir", ] [[package]] name = "rust-embed-utils" -version = "8.4.0" +version = "8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c74a686185620830701348de757fd36bef4aa9680fd23c49fc539ddcc1af32" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" dependencies = [ "sha2", "walkdir", @@ -3213,18 +3603,18 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" dependencies = [ "bitflags 2.6.0", "errno", @@ -3261,9 +3651,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.10" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "aws-lc-rs", "log", @@ -3276,12 +3666,12 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", "security-framework", @@ -3298,25 +3688,24 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.1", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3326,9 +3715,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -3347,11 +3736,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3360,6 +3749,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3372,9 +3767,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3385,9 +3780,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3404,31 +3799,42 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3445,9 +3851,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] @@ -3502,9 +3908,9 @@ dependencies = [ [[package]] name = "signal-hook-mio" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", "mio", @@ -3602,6 +4008,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" version = "0.11.1" @@ -3610,24 +4022,24 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "strum" -version = "0.25.0" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.25.3" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -3649,9 +4061,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -3672,15 +4084,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = 
"sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3744,29 +4157,62 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "text-generation-backends-trtllm" +version = "2.4.0" +dependencies = [ + "async-stream", + "async-trait", + "clap 4.5.20", + "cmake", + "cxx", + "cxx-build", + "hashbrown 0.14.5", + "hf-hub", + "log", + "pkg-config", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tracing", + "tracing-opentelemetry 0.25.0", + "tracing-subscriber", ] [[package]] name = "text-generation-benchmark" -version = "2.1.1" +version = "2.4.0" dependencies = [ "average", - "clap", - "crossterm", + "clap 4.5.20", "float-ord", "hf-hub", "ratatui", @@ -3783,7 +4229,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.1.1" +version = "2.4.0" dependencies = [ "async-trait", "base64 0.22.1", @@ -3795,20 +4241,22 @@ dependencies = [ "tokio", "tonic 0.10.2", "tonic-build", - "tower", + "tower 0.4.13", "tracing", ] [[package]] name = "text-generation-launcher" -version = "2.1.1" +version = "2.4.0" dependencies = [ - "clap", + "clap 4.5.20", "ctrlc", "float_eq", "hf-hub", - "nix", + "nix 0.28.0", "once_cell", + "pyo3", + "regex", "reqwest", "serde", "serde_json", @@ -3820,13 +4268,15 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.1.1" +version = "2.4.0" dependencies = [ "async-stream", - "axum 0.7.5", + "async-trait", + "axum 0.7.7", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.20", + "csv", "futures", "futures-util", "hf-hub", @@ -3834,7 +4284,7 @@ dependencies = [ "init-tracing-opentelemetry", "itertools 0.10.5", "jsonschema", - "metrics 0.21.1", + "metrics", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", @@ -3843,12 +4293,13 @@ dependencies = [ "once_cell", "opentelemetry 0.20.0", "opentelemetry-otlp", + "pyo3", "rand", "regex", "reqwest", "serde", "serde_json", - "text-generation-client", + "sysinfo", "thiserror", "tokenizers", "tokio", @@ -3857,29 +4308,140 @@ dependencies = [ "tracing", "tracing-opentelemetry 0.21.0", "tracing-subscriber", + "ureq", "utoipa", "utoipa-swagger-ui", + "uuid", "vergen", ] +[[package]] +name = "text-generation-router-v2" +version = "2.4.0" +dependencies = [ + "async-stream", + 
"async-trait", + "axum 0.7.7", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap 4.5.20", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "slotmap", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + +[[package]] +name = "text-generation-router-v3" +version = "2.4.0" +dependencies = [ + "async-stream", + "async-trait", + "axum 0.7.7", + "axum-tracing-opentelemetry", + "base64 0.22.1", + "clap 4.5.20", + "criterion", + "futures", + "futures-util", + "grpc-metadata", + "hf-hub", + "image", + "init-tracing-opentelemetry", + "itertools 0.13.0", + "jsonschema", + "metrics", + "metrics-exporter-prometheus", + "minijinja", + "minijinja-contrib", + "nohash-hasher", + "once_cell", + "opentelemetry 0.20.0", + "opentelemetry-otlp", + "prost 0.12.6", + "prost-build", + "rand", + "regex", + "reqwest", + "serde", + "serde_json", + "slotmap", + "text-generation-router", + "thiserror", + "tokenizers", + "tokio", + "tokio-stream", + "tonic 0.10.2", + "tonic-build", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-opentelemetry 0.21.0", + "tracing-subscriber", + "utoipa", + "utoipa-swagger-ui", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -3936,11 +4498,21 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -3953,9 +4525,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.19.1" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" +checksum = "b172ffa9a2e5c31bbddc940cd5725d933ced983a9333bbebc4c7eda3bbce1557" dependencies = [ "aho-corasick", "derive_builder", @@ -3974,7 +4546,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "serde", "serde_json", "spm_precompiled", @@ -3986,21 +4558,20 @@ dependencies = [ [[package]] name = "tokio" -version = "1.38.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -4015,13 +4586,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -4051,16 +4622,16 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.10", + "rustls 0.23.15", "rustls-pki-types", "tokio", ] [[package]] name = "tokio-stream" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" dependencies = [ "futures-core", "pin-project-lite", @@ -4069,9 +4640,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", @@ -4083,9 +4654,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.14" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -4095,20 +4666,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "serde", "serde_spanned", "toml_datetime", @@ -4130,14 +4701,14 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - 
"hyper 0.14.29", + "hyper 0.14.31", "hyper-timeout", "percent-encoding", "pin-project", "prost 0.11.9", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -4157,14 +4728,14 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.31", "hyper-timeout", "percent-encoding", "pin-project", "prost 0.12.6", "tokio", "tokio-stream", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -4180,7 +4751,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -4203,6 +4774,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.5.2" @@ -4212,7 +4799,7 @@ dependencies = [ "bitflags 2.6.0", "bytes", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tower-layer", @@ -4221,15 +4808,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4251,7 +4838,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -4317,7 +4904,25 @@ dependencies = [ "tracing-core", "tracing-log 0.2.0", "tracing-subscriber", - "web-time", + "web-time 0.2.4", +] + +[[package]] +name = "tracing-opentelemetry" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry 0.24.0", + "opentelemetry_sdk 0.24.1", + "smallvec", + "tracing", + "tracing-core", + "tracing-log 0.2.0", + "tracing-subscriber", + "web-time 1.1.0", ] [[package]] @@ -4377,30 +4982,27 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicase" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" -dependencies = [ - "version_check", -] +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" [[package]] name = "unicode-bidi" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] @@ -4416,15 +5018,26 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-truncate" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" +dependencies = [ + "itertools 0.13.0", + "unicode-segmentation", + "unicode-width", +] [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode_categories" @@ -4432,6 +5045,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.7.1" @@ -4493,7 +5112,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.6.0", "serde", "serde_json", "utoipa-gen", @@ -4501,15 +5120,15 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "4.3.0" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bf0e16c02bc4bf5322ab65f10ab1149bdbcaa782cba66dc7057370a3f8190be" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" dependencies = [ "proc-macro-error", "proc-macro2", "quote", "regex", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -4518,7 +5137,7 @@ version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" dependencies = [ - "axum 0.7.5", + "axum 0.7.7", "mime_guess", "regex", "rust-embed", @@ -4530,9 +5149,25 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b91f57fe13a38d0ce9e28a03463d8d3c2468ed03d75375110ec71d93b449a08" +dependencies = [ + "proc-macro2", + "quote", + "syn 
2.0.85", +] [[package]] name = "v_frame" @@ -4559,9 +5194,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.3.1" +version = "8.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27d6bdd219887a9eadd19e1c34f32e47fa332301184935c6d9bca26f3cca525" +checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", "cargo_metadata", @@ -4581,9 +5216,9 @@ checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -4612,34 +5247,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4649,9 +5285,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4659,28 +5295,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = 
"f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -4696,6 +5332,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki" version = "0.22.4" @@ -4708,9 +5354,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.3" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] @@ -4751,11 +5397,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4771,7 +5417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4780,7 +5426,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4807,7 +5453,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -4842,18 +5497,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4870,9 +5525,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" 
+checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4888,9 +5543,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4906,15 +5561,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4930,9 +5585,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4948,9 +5603,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4966,9 +5621,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4984,15 +5639,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] @@ -5009,22 +5664,23 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.85", ] [[package]] @@ -5032,20 +5688,6 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] [[package]] name = "zip" @@ -5076,9 +5718,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" +checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768" dependencies = [ "zune-core", ] diff --git a/Cargo.toml b/Cargo.toml index 35abe2a14a59589e0afeefb86bc366f699f5f224..07bc967e33cd108dd0a5773ee97b5c3e1148dde2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,23 +1,39 @@ [workspace] members = [ - "benchmark", - "router", - "router/client", - "router/grpc-metadata", - "launcher" + "benchmark", + "backends/v2", + "backends/v3", + "backends/grpc-metadata", + "backends/trtllm", + "launcher", + "router" +] +default-members = [ + "benchmark", + "backends/v2", + "backends/v3", + "backends/grpc-metadata", + # "backends/trtllm", + "launcher", + "router" ] resolver = "2" [workspace.package] -version = "2.1.1" +version = "2.4.0" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] base64 = "0.22.0" -tokenizers = { version = "0.19.1", features = ["http"] } +tokenizers = { version = "0.20.0", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } +metrics = { version = "0.23.0" } +metrics-exporter-prometheus = { version = "0.15.1", features = [] } +minijinja = { version = "2.2.0", features = ["json"] } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +pyo3 = { version = "0.22.2", features = ["auto-initialize"] } [profile.release] incremental = true diff --git a/Dockerfile b/Dockerfile index d4772b4a7221238cafbec05d278fa92fc3bb3c78..d4189c9f68dccda75ada2bdfa3dd164485287a1c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -11,11 +11,15 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher + RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN 
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -28,22 +32,26 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Python builder # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install -ARG PYTORCH_VERSION=2.3.0 -ARG PYTHON_VERSION=3.10 +# NOTE: When updating the PyTorch version, remember to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 +ARG PYTORCH_VERSION=2.4.0 + +ARG PYTHON_VERSION=3.11 # Keep in sync with `server/pyproject.toml -ARG CUDA_VERSION=12.1 +ARG CUDA_VERSION=12.4 ARG MAMBA_VERSION=24.3.0-0 ARG CUDA_CHANNEL=nvidia ARG INSTALL_CHANNEL=pytorch @@ -84,6 +92,7 @@ RUN case ${TARGETPLATFORM} in \ FROM pytorch-install AS kernel-builder ARG MAX_JOBS=8 +ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX" RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ ninja-build cmake \ @@ -114,36 +123,29 @@ FROM kernel-builder AS exllama-kernels-builder WORKDIR /usr/src COPY server/exllama_kernels/ . -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN python setup.py build # Build Transformers exllama kernels FROM kernel-builder AS exllamav2-kernels-builder WORKDIR /usr/src -COPY server/exllamav2_kernels/ . +COPY server/Makefile-exllamav2/ Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN make build-exllamav2 # Build Transformers awq kernels FROM kernel-builder AS awq-kernels-builder WORKDIR /usr/src COPY server/Makefile-awq Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq +RUN make build-awq # Build eetq kernels FROM kernel-builder AS eetq-kernels-builder WORKDIR /usr/src COPY server/Makefile-eetq Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq - -# Build marlin kernels -FROM kernel-builder AS marlin-kernels-builder -WORKDIR /usr/src -COPY server/marlin/ . 
-# Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN make build-eetq # Build Lorax Punica kernels FROM kernel-builder AS lorax-punica-builder @@ -177,6 +179,12 @@ WORKDIR /usr/src COPY server/Makefile-selective-scan Makefile RUN make build-all +# Build flashinfer +FROM kernel-builder AS flashinfer-builder +WORKDIR /usr/src +COPY server/Makefile-flashinfer Makefile +RUN make install-flashinfer + # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base @@ -185,7 +193,7 @@ ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 @@ -203,33 +211,31 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins COPY --from=pytorch-install /opt/conda /opt/conda # Copy build artifacts from flash attention builder -COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from flash attention v2 builder -COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages # Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllama kernels builder -COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from awq kernels builder -COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from eetq kernels builder -COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 
/opt/conda/lib/python3.10/site-packages -# Copy build artifacts from marlin kernels builder -COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages -COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - -# Copy builds artifacts from vllm builder -COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - +COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +# Copy build artifacts from lorax punica kernels builder +COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages +# Copy build artifacts from vllm builder +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from mamba builder -COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages -COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages +COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/ # Install flash-attention dependencies RUN pip install einops --no-cache-dir @@ -241,7 +247,15 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \ + pip install nvidia-nccl-cu12==2.22.3 + +ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2 +# Required to find libpython within the rust binaries +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" +# This is needed because exl2 tries to load flash-attn +# And fails with our builds. +ENV EXLLAMA_NO_FLASH_ATTN=1 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/Dockerfile.nix b/Dockerfile.nix new file mode 100644 index 0000000000000000000000000000000000000000..f1e7e0f553e19614eb35c08856c6aaaf891b0782 --- /dev/null +++ b/Dockerfile.nix @@ -0,0 +1,24 @@ +# Build the image and get out the docker file: +# +# docker build -t tgi-nix-builder -f Dockerfile.nix +# docker run --log-driver=none tgi-nix-builder | docker load + +FROM nixos/nix:2.18.8 AS builder +RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf +RUN nix profile install nixpkgs#cachix +RUN cachix use text-generation-inference +WORKDIR /root +ADD . . +RUN nix build . 
+RUN mkdir /tmp/nix-store-closure +RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure + +FROM ubuntu:24.04 + +WORKDIR /app + +# Copy /nix/store +COPY --from=builder /tmp/nix-store-closure /nix/store +COPY --from=builder /root/result /app +RUN ldconfig +CMD ["ldconfig", "/app/bin/text-generation-launcher"] diff --git a/Dockerfile_amd b/Dockerfile_amd index 0aebeee5747b4f55ce63ee4bcbf1fa353065366c..b84d4edd80247d2a31eadb8dfca688e1a23fd73f 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -11,11 +11,14 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -28,16 +31,18 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Text Generation Inference base image for RoCm -FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base +FROM rocm/dev-ubuntu-22.04:6.2 AS base RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ build-essential \ @@ -46,33 +51,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ git \ make \ + libmsgpack-dev \ libssl-dev \ + llvm-dev \ g++ \ # Needed to build VLLM & flash. rocthrust-dev \ hipsparse-dev \ hipblas-dev \ - hipblaslt-dev \ + hipcub-dev \ rocblas-dev \ hiprand-dev \ + hipfft-dev \ rocrand-dev \ miopen-hip-dev \ - hipfft-dev \ - hipcub-dev \ hipsolver-dev \ rccl-dev \ cmake \ - python3-dev && \ + python3.11-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml ARG MAMBA_VERSION=23.1.0-1 -ARG PYTORCH_VERSION='2.3.0' -ARG ROCM_VERSION='6.0.2' -ARG PYTHON_VERSION='3.10.10' +ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/conda/bin:$PATH + +ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba @@ -87,42 +93,141 @@ RUN chmod +x ~/mambaforge.sh && \ mamba init && \ rm ~/mambaforge.sh -# Install flash-attention, torch dependencies -RUN pip install numpy einops ninja --no-cache-dir - -RUN conda install intel::mkl-static intel::mkl-include -RUN pip uninstall -y triton && \ - git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ - cd triton/python && \ - pip install . 
- -RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir +# RUN conda install intel::mkl-static intel::mkl-include +# Install pytorch +# On arm64 we exit with an error code +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya -ARG _GLIBCXX_USE_CXX11_ABI="1" -ARG CMAKE_PREFIX_PATH="/opt/conda" +# Install flash-attention, torch dependencies +RUN python3 -m pip install --upgrade pip && pip install numpy einops ninja joblib msgpack cmake --no-cache-dir && rm -rf /var/lib/apt/lists/* + +RUN conda install mkl=2021 +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/ + + +ARG COMMON_WORKDIR=/ +WORKDIR ${COMMON_WORKDIR} + + +# Install HIPBLASLt +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH="e6da924" +RUN git clone https://github.com/ROCm/hipBLASLt.git \ + && cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \ + && cd build/release \ + && make package + +FROM scratch AS export_hipblaslt +ARG COMMON_WORKDIR +COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb / + +# RCCL build stages +FROM base AS build_rccl +ARG RCCL_BRANCH="rocm-6.2.0" +RUN git clone https://github.com/ROCm/rccl \ + && cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +FROM scratch AS export_rccl +ARG COMMON_WORKDIR +COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb / + +# Triton build stages +FROM base AS build_triton +ARG TRITON_BRANCH="e192dba" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch AS export_triton +ARG COMMON_WORKDIR +COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl / + +# # AMD-SMI build stages +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . --wheel-dir=dist +FROM scratch AS export_amdsmi +COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl / + + +FROM base as build_pytorch + +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -ARG BUILD_CAFFE2="0" \ - BUILD_CAFFE2_OPS="0" \ - USE_CUDA="0" \ - USE_ROCM="1" \ - BUILD_TEST="0" \ - USE_FBGEMM="0" \ - USE_NNPACK="0" \ - USE_QNNPACK="0" \ - USE_XNNPACK="0" \ - USE_FLASH_ATTENTION="1" \ - USE_MEM_EFF_ATTENTION="0" - -RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install -# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm -ENV HIP_FORCE_DEV_KERNARG=1 - -# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. 
-# However, Triton requires a tunning for each prompt length, which is prohibitive. -ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 - -FROM base AS kernel-builder +# A commit to fix the output scaling factor issue in _scaled_mm +# Not yet in 2.5.0-rc1 +ARG PYTORCH_BRANCH="cedc116" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" + +RUN git clone ${PYTORCH_REPO} pytorch \ + && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \ + && pip install -r requirements.txt --no-cache-dir \ + && python tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch as export_pytorch +ARG COMMON_WORKDIR +COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl / + +FROM base AS install_deps + +ARG COMMON_WORKDIR + +# Install hipblaslt +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_rccl,src=/,target=/install \ + if ls /install/*.deb; then \ + dpkg -i /install/*.deb \ + # RCCL needs to be installed twice + && dpkg -i /install/*.deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \ + fi + +RUN --mount=type=bind,from=export_triton,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y triton \ + && pip install /install/*.whl; \ + fi + +RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y amdsmi \ + && pip install /install/*.whl; + +RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y torch torchvision \ + && pip install /install/*.whl; \ + fi + +FROM install_deps AS kernel-builder # # Build vllm kernels FROM kernel-builder AS vllm-builder @@ -162,27 +267,27 @@ COPY server/exllamav2_kernels/ . 
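The kernel artifacts above are staged through empty `FROM scratch AS export_*` images so they can be bind-mounted into later stages, or pulled out of the build on their own. A sketch of extracting one of them locally, assuming a BuildKit-enabled `docker build` (target and destination directory are illustrative):

```shell
# Copies only the Triton wheels from the scratch-based export stage into
# ./wheels; nothing else from the ROCm build is exported.
DOCKER_BUILDKIT=1 docker build -f Dockerfile_amd \
  --target export_triton \
  --output type=local,dest=./wheels .
```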
RUN python setup.py build -FROM base AS base-copy +FROM install_deps AS base-copy # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 # Copy builds artifacts from vllm builder -COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from flash attention v2 builder -COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllama kernels builder -COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages # Install server COPY proto proto @@ -199,6 +304,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router # Install launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" # AWS Sagemaker compatible image FROM base AS sagemaker @@ -211,6 +317,20 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base-copy +# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm +ENV HIP_FORCE_DEV_KERNARG=1 + +# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. +# However, Triton requires a tunning for each prompt length, which is prohibitive. 
+ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 +ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 +ENV VLLM_MOE_PADDING=0 +ENV ATTENTION=paged +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 +ENV ROCM_USE_SKINNY_GEMM=1 + COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/Dockerfile_dcu b/Dockerfile_dcu deleted file mode 100644 index e0d1d2c36f0eede67f6d24acfeabd642ede780a7..0000000000000000000000000000000000000000 --- a/Dockerfile_dcu +++ /dev/null @@ -1,130 +0,0 @@ -# Rust builder -FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310 as chef -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -ENV PATH /root/.cargo/bin:$PATH -RUN cargo install cargo-chef -WORKDIR /usr/src - -ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse - -FROM chef as planner -COPY Cargo.toml Cargo.toml -COPY Cargo.lock Cargo.lock -COPY rust-toolchain.toml rust-toolchain.toml -COPY proto proto -COPY benchmark benchmark -COPY router router -COPY launcher launcher -RUN cargo chef prepare --recipe-path recipe.json - -FROM chef AS builder - -ARG GIT_SHA -ARG DOCKER_LABEL - -RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ - curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ - unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ - unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ - rm -f $PROTOC_ZIP -COPY --from=planner /usr/src/recipe.json recipe.json -RUN cargo chef cook --release --recipe-path recipe.json - -COPY Cargo.toml Cargo.toml -COPY Cargo.lock Cargo.lock -COPY rust-toolchain.toml rust-toolchain.toml -COPY proto proto -COPY benchmark benchmark -COPY router router -COPY launcher launcher -RUN cargo build --release - -# Text Generation Inference base image for RoCm -FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310 as base -# Need hyhal while compiling -WORKDIR /opt -RUN wget https://cancon.hpccube.com:65024/directlink/1/DTK-23.10.1/hyhal.tar.gz && \ - tar -xzf hyhal.tar.gz -C /opt - -ENV LD_LIBRARY_PATH /opt/hyhal/lib:/opt/hyhal/lib64:$LD_LIBRARY_PATH -ENV PYTHONPATH /usr/local/lib/python3.10/site-packages:$PYTHONPATH - -FROM base AS kernel-builder - -# Build vllm kernels -FROM kernel-builder AS vllm-builder -WORKDIR /usr/src -COPY server/vllm/ . - -# Build specific version of vllm -RUN python setup.py build - -# Build Transformers CUDA kernels (gpt-neox and bloom) -FROM kernel-builder as custom-kernels-builder -WORKDIR /usr/src -COPY server/custom_kernels/ . -RUN python setup.py build - -# Build exllama kernels -FROM kernel-builder as exllama-kernels-builder -WORKDIR /usr/src -COPY server/exllama_kernels/ . - -RUN python setup.py build - -# Build exllama v2 kernels -FROM kernel-builder as exllamav2-kernels-builder -WORKDIR /usr/src -COPY server/exllamav2_kernels/ . 
- -RUN python setup.py build - -FROM base as base-copy - -# uninstall exist vllm in base docker image -RUN pip uninstall -y vllm - -# Copy builds artifacts from vllm builder -COPY --from=vllm-builder /usr/src/build/lib.linux-x86_64-cpython-310 /usr/local/lib/python3.10/site-packages - -# Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /usr/local/lib/python3.10/site-packages - -# Copy build artifacts from exllama kernels builder -COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /usr/local/lib/python3.10/site-packages - -# Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /usr/local/lib/python3.10/site-packages - -# Install server -RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple -COPY proto proto -COPY server server -COPY server/Makefile server/Makefile -RUN cd server && \ - make gen-server && \ - pip install -r requirements_rocm.txt && \ - pip install ".[accelerate, peft, outlines]" --no-cache-dir - -# Install benchmarker -COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark -# Install router -COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router -# Install launcher -COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher - -#Remove default hyhal -RUN rm -rf /opt/hyhal /opt/hyhal.tar.gz - -# AWS Sagemaker compatible image -# FROM base-copy as sagemaker -# COPY sagemaker-entrypoint.sh entrypoint.sh -# RUN chmod +x entrypoint.sh - -# ENTRYPOINT ["./entrypoint.sh"] - -# # Final image -# FROM base-copy - -# ENTRYPOINT ["text-generation-launcher"] -# CMD ["--json-output"] diff --git a/Dockerfile_intel b/Dockerfile_intel index 6a803a32bacd0bf5a6fa8f9ffeabfe8506ff7fc2..96f242489abe7a72c898d3678179576fcfa2138e 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -12,11 +12,14 @@ COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3.11-dev RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ @@ -29,20 +32,48 @@ RUN cargo chef cook --profile release-opt --recipe-path recipe.json ARG GIT_SHA ARG DOCKER_LABEL +COPY Cargo.lock Cargo.lock COPY Cargo.toml Cargo.toml COPY rust-toolchain.toml rust-toolchain.toml COPY proto proto COPY benchmark benchmark COPY router router +COPY backends backends COPY launcher launcher -RUN cargo build --profile release-opt +RUN cargo build --profile release-opt --frozen # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu +FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu USER root + +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTHON_VERSION='3.11.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV 
PATH=/opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb @@ -52,18 +83,16 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 WORKDIR /usr/src -RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl -RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed +RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir # Install server COPY proto proto @@ -78,13 +107,13 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib -ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: -ENV 
PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib +ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV CCL_ZE_IPC_EXCHANGE=sockets ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include - -RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch +ENV TORCH_LLM_ALLREDUCE=1 +ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark @@ -101,17 +130,28 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ ca-certificates \ make \ - g++ \ + g++-12 \ + gcc-12 \ git \ wget \ - cmake + cmake \ + libnuma-dev + +RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12 +RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30 +RUN update-alternatives --set cc /usr/bin/gcc + +RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30 +RUN update-alternatives --set c++ /usr/bin/g++ + ENV HUGGINGFACE_HUB_CACHE=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 ARG MAMBA_VERSION=23.1.0-1 -ARG PYTHON_VERSION='3.10.10' +ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM ENV PATH /opt/conda/bin:$PATH @@ -128,33 +168,37 @@ RUN chmod +x ~/mambaforge.sh && \ bash ~/mambaforge.sh -b -p /opt/conda && \ rm ~/mambaforge.sh +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + RUN conda install -c conda-forge gperftools mkl -RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install triton -WORKDIR /usr/src +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl +RUN pip install 
https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a +RUN pip install triton py-libnuma + +WORKDIR /usr/src -RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout f86e93e4890dc2c989024d148d415c9aa8a1649f +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout v2.4.0+cpu+rc0 RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . -ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so -ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch -ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch -ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric -ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib -ENV KMP_BLOCKTIME=1 -ENV KMP_TPAUSE=0 -ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist -ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist -ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so +ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch +ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch +ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/" # Install server COPY proto proto @@ -173,5 +217,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final +ENV ATTENTION=paged +ENV PREFIX_CACHING=0 +ENV PREFILL_CHUNKING=0 +ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/Dockerfile_trtllm b/Dockerfile_trtllm new file mode 100644 index 0000000000000000000000000000000000000000..3ccb0310bea77eb9f8c8f60a982e46753e66847e --- /dev/null +++ b/Dockerfile_trtllm @@ -0,0 +1,108 @@ +ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real" +ARG OMPI_VERSION="4.1.6" + +# Build dependencies resolver stage +FROM lukemathwalker/cargo-chef:latest AS chef +WORKDIR /usr/src/text-generation-inference/backends/trtllm + +FROM chef AS planner +COPY . . 
+RUN cargo chef prepare --recipe-path recipe.json + +# CUDA dependent dependencies resolver stage +FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder + +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt update && apt install -y \ + build-essential \ + cmake \ + curl \ + gcc \ + g++ \ + git \ + git-lfs \ + libssl-dev \ + ninja-build \ + pkg-config \ + python3 \ + python3-dev \ + python3-setuptools \ + tar \ + wget + +ENV TGI_INSTALL_PREFIX=/usr/local/tgi +ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt + +# Install OpenMPI +FROM cuda-builder AS mpi-builder +ARG OMPI_VERSION + +ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2" +RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \ + mkdir /usr/src/mpi && \ + tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ + cd /usr/src/mpi && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \ + make -j all && \ + make install && \ + rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" + +# Install TensorRT +FROM cuda-builder AS trt-builder +COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh +RUN chmod +x /opt/install_tensorrt.sh && \ + /opt/install_tensorrt.sh + +# Build Backend +FROM cuda-builder AS tgi-builder +WORKDIR /usr/src/text-generation-inference + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \ + chmod -R a+w /root/.rustup && \ + chmod -R a+w /root/.cargo + +ENV PATH="/root/.cargo/bin:$PATH" +RUN cargo install cargo-chef + +# Cache dependencies +COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json . +RUN cargo chef cook --release --recipe-path recipe.json + +# Build actual TGI +ARG CUDA_ARCH_LIST +ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH" +ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH" +ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH" + +COPY . . +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ + cd backends/trtllm && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release + +FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime +RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \ + rm -rf /var/lib/{apt,dpkg,cache,log}/ && \ + python3 -m pip install transformers tokenizers + +WORKDIR /usr/local/tgi/bin + +ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" +ENV TOKENIZERS_PARALLELISM=false +ENV OMPI_MCA_plm_rsh_agent="" + +COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi +COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt +COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi +COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher + +FROM runtime + +LABEL co.huggingface.vendor="Hugging Face Inc." 
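For orientation, building this new TensorRT-LLM image could look roughly like the sketch below; the tag is made up, and the narrowed `CUDA_ARCH_LIST` simply overrides the default declared at the top of the file.

```shell
# Hypothetical local build of the TensorRT-LLM backend image (Hopper only).
docker build -f Dockerfile_trtllm \
  --build-arg CUDA_ARCH_LIST="90-real" \
  -t tgi-trtllm:dev .
```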
+LABEL org.opencontainers.image.authors="hardware@hf.co" + +ENTRYPOINT ["./text-generation-launcher"] +CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"] diff --git a/Makefile b/Makefile index a434d4f4ac8b150512fcdeb94f470a18e965adeb..3068a06f41b80be82e81ab44b768e5a113828a57 100644 --- a/Makefile +++ b/Makefile @@ -5,13 +5,13 @@ install-server-cpu: cd server && make install-server install-router: - cd router && cargo install --path . --debug + cargo install --path backends/v3/ install-launcher: - cd launcher && cargo install --path . + cargo install --path launcher/ install-benchmark: - cd benchmark && cargo install --path . + cargo install --path benchmark/ install: install-server install-router install-launcher diff --git a/README.md b/README.md index a4df89802d73852c50300077b902b63a107be20c..7ab00190203e9a77ab124b061ca6bc50fe62eccc 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,214 @@ -
Text Generation Inference
+
-## 简介 -Text Generation Inference(TGI)是一个用 Rust 和 Python 编写的框架,用于部署和提供LLM模型的推理服务。TGI为很多大模型提供了高性能的推理服务,如LLama,Falcon,BLOOM,Baichuan,Qwen等。 + + Making TGI deployment optimal + -## 支持模型结构列表 -| 模型 | 模型并行 | FP16 | -| :----------: | :------: | :--: | -| LLaMA | Yes | Yes | -| LLaMA-2 | Yes | Yes | -| LLaMA-2-GPTQ | Yes | Yes | -| LLaMA-3 | Yes | Yes | -| Codellama | Yes | Yes | -| QWen2 | Yes | Yes | -| QWen2-GPTQ | Yes | Yes | -| Baichuan-7B | Yes | Yes | -| Baichuan2-7B | Yes | Yes | -| Baichuan2-13B | Yes | Yes | +# Text Generation Inference + + GitHub Repo stars + + + Swagger API documentation + -## 环境要求 -+ Python 3.10 -+ DTK 24.04.2 -+ torch 2.1.0 +A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co) +to power Hugging Chat, the Inference API and Inference Endpoint. -### 使用源码编译方式安装 +
-#### 编译环境准备 +## Table of contents -有两种方式安装准备环境 -##### 方式一: + - [Get Started](#get-started) + - [Docker](#docker) + - [API documentation](#api-documentation) + - [Using a private or gated model](#using-a-private-or-gated-model) + - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm) + - [Distributed Tracing](#distributed-tracing) + - [Architecture](#architecture) + - [Local install](#local-install) + - [Optimized architectures](#optimized-architectures) + - [Run locally](#run-locally) + - [Run](#run) + - [Quantization](#quantization) + - [Develop](#develop) + - [Testing](#testing) -### **TODO** +Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as: -##### 方式二: +- Simple launcher to serve most popular LLMs +- Production ready (distributed tracing with Open Telemetry, Prometheus metrics) +- Tensor Parallelism for faster inference on multiple GPUs +- Token streaming using Server-Sent Events (SSE) +- Continuous batching of incoming requests for increased total throughput +- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API +- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures +- Quantization with : + - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) + - [GPT-Q](https://arxiv.org/abs/2210.17323) + - [EETQ](https://github.com/NetEase-FuXi/EETQ) + - [AWQ](https://github.com/casper-hansen/AutoAWQ) + - [Marlin](https://github.com/IST-DASLab/marlin) + - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) +- [Safetensors](https://github.com/huggingface/safetensors) weight loading +- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) +- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) +- Stop sequences +- Log probabilities +- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency +- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify output format to speed up inference and make sure the output is valid according to some specs.. 
+- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output +- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance -基于光源pytorch2.1.0基础镜像环境:镜像下载地址:[https://sourcefind.cn/#/image/dcu/pytorch](https://sourcefind.cn/#/image/dcu/pytorch),根据pytorch2.1.0、python、dtk及系统下载对应的镜像版本。pytorch2.1.0镜像里已经安装了trition,flash-attn +### Hardware support + +- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) +- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm) +- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) +- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475) +- [Gaudi](https://github.com/huggingface/tgi-gaudi) +- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving) + + +## Get Started + +### Docker + +For a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). The easiest way of getting started is using the official Docker container: + +```shell +model=HuggingFaceH4/zephyr-7b-beta +# share a volume with the Docker container to avoid downloading weights every run +volume=$PWD/data + +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:2.4.0 --model-id $model +``` + +And then you can make requests like + +```bash +curl 127.0.0.1:8080/generate_stream \ + -X POST \ + -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -H 'Content-Type: application/json' +``` + +You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses. + +```bash +curl localhost:8080/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ + -H 'Content-Type: application/json' +``` + +**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. + +**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.4.0-rocm --model-id $model` instead of the command above. + +To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): +``` +text-generation-launcher --help +``` + +### API documentation + +You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. 
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference). + +### Using a private or gated model + +You have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by +`text-generation-inference`. This allows you to gain access to protected resources. + +For example, if you want to serve the gated Llama V2 model variants: + +1. Go to https://huggingface.co/settings/tokens +2. Copy your cli READ token +3. Export `HF_TOKEN=` + +or with Docker: + +```shell +model=meta-llama/Meta-Llama-3.1-8B-Instruct +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run +token= + +docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.4.0 --model-id $model +``` + +### A note on Shared Memory (shm) + +[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by +`PyTorch` to do distributed training/inference. `text-generation-inference` make +use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models. + +In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if +peer-to-peer using NVLink or PCI is not possible. + +To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command. + +If you are running `text-generation-inference` inside `Kubernetes`. You can also add Shared Memory to the container by +creating a volume with: + +```yaml +- name: shm + emptyDir: + medium: Memory + sizeLimit: 1Gi +``` + +and mounting it to `/dev/shm`. + +Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that +this will impact performance. + +### Distributed Tracing + +`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature +by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be +overridden with the `--otlp-service-name` argument + +### Architecture + +![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) + +Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi) + +### Local install + +You can also opt to install `text-generation-inference` locally. + +First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least +Python 3.9, e.g. using `conda`: -1. 安装Rust ```shell curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +conda create -n text-generation-inference python=3.11 +conda activate text-generation-inference ``` -2. 安装Protoc +You may also need to install Protoc. + +On Linux: + ```shell PROTOC_ZIP=protoc-21.12-linux-x86_64.zip curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP @@ -49,50 +216,77 @@ sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*' rm -f $PROTOC_ZIP ``` -3. 
安装TGI Service -```bash -git clone http://developer.hpccube.com/codes/wangkx1/text_generation_server-dcu.git # 根据需要的分支进行切换 -cd text-generation-inference -pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple -pip install -r pre_requirements.txt +On MacOS, using Homebrew: -#安装exllama -cd server -make install-exllama #安装exllama kernels -make install-exllamav2 #安装exllmav2 kernels +```shell +brew install protobuf +``` -cd .. #回到项目根目录 -source $HOME/.cargo/env -BUILD_EXTENSIONS=True make install #安装text-generation服务 +Then run: +```shell +BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels +text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 ``` -4. 安装benchmark -```bash -cd text-generation-inference -make install-benchmark -``` -注意:若安装过程过慢,可以通过如下命令修改默认源提速。 -```bash +**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: + +```shell +sudo apt-get install libssl-dev gcc -y ``` -另外,`cargo install` 太慢也可以通过在`~/.cargo/config`中添加源来提速。 -## 查看安装的版本号 -```bash -text-generation-launcher -V #版本号与官方版本同步 +## Optimized architectures + +TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models). + +Other architectures are supported on a best-effort basis using: + +`AutoModelForCausalLM.from_pretrained(, device_map="auto")` + +or + +`AutoModelForSeq2SeqLM.from_pretrained(, device_map="auto")` + + + +## Run locally + +### Run + +```shell +text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 ``` -## 使用前 +### Quantization -```bash -export PYTORCH_TUNABLEOP_ENABLED=0 +You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement: + +```shell +text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize ``` -## Known Issue +4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. + +Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization). + +## Develop -- 无 +```shell +make server-dev +make router-dev +``` + +## Testing -## 参考资料 -- [README_ORIGIN](README_ORIGIN.md) -- [https://github.com/huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference) +```shell +# python +make python-server-tests +make python-client-tests +# or both server and client tests +make python-tests +# rust cargo tests +make rust-tests +# integration tests +make integration-tests +``` diff --git a/README_ORINGIN.md b/README_ORINGIN.md deleted file mode 100644 index 74616748efa88a47785eaff7c1952370aaad20ea..0000000000000000000000000000000000000000 --- a/README_ORINGIN.md +++ /dev/null @@ -1,258 +0,0 @@ -
- - - Making TGI deployment optimal - - -# Text Generation Inference - - - GitHub Repo stars - - - Swagger API documentation - - -A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) -to power Hugging Chat, the Inference API and Inference Endpoint. - -
- -## Table of contents - -- [Get Started](#get-started) - - [API Documentation](#api-documentation) - - [Using a private or gated model](#using-a-private-or-gated-model) - - [A note on Shared Memory](#a-note-on-shared-memory-shm) - - [Distributed Tracing](#distributed-tracing) - - [Local Install](#local-install) - - [CUDA Kernels](#cuda-kernels) -- [Optimized architectures](#optimized-architectures) -- [Run Mistral](#run-a-model) - - [Run](#run) - - [Quantization](#quantization) -- [Develop](#develop) -- [Testing](#testing) - -Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as: - -- Simple launcher to serve most popular LLMs -- Production ready (distributed tracing with Open Telemetry, Prometheus metrics) -- Tensor Parallelism for faster inference on multiple GPUs -- Token streaming using Server-Sent Events (SSE) -- Continuous batching of incoming requests for increased total throughput -- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures -- Quantization with : - - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - - [GPT-Q](https://arxiv.org/abs/2210.17323) - - [EETQ](https://github.com/NetEase-FuXi/EETQ) - - [AWQ](https://github.com/casper-hansen/AutoAWQ) -- [Safetensors](https://github.com/huggingface/safetensors) weight loading -- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) -- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) -- Stop sequences -- Log probabilities -- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency -- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify output format to speed up inference and make sure the output is valid according to some specs.. -- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output -- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance - -### Hardware support - -- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) -- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm) -- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference) -- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475) -- [Gaudi](https://github.com/huggingface/tgi-gaudi) -- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving) - - -## Get Started - -### Docker - -For a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). 
The easiest way of getting started is using the official Docker container: - -```shell -model=HuggingFaceH4/zephyr-7b-beta -volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run - -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model -``` - -And then you can make requests like - -```bash -curl 127.0.0.1:8080/generate_stream \ - -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ - -H 'Content-Type: application/json' -``` - -**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. - -**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above. - -To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli): -``` -text-generation-launcher --help -``` - -### API documentation - -You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. -The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference). - -### Using a private or gated model - -You have the option to utilize the `HUGGING_FACE_HUB_TOKEN` environment variable for configuring the token employed by -`text-generation-inference`. This allows you to gain access to protected resources. - -For example, if you want to serve the gated Llama V2 model variants: - -1. Go to https://huggingface.co/settings/tokens -2. Copy your cli READ token -3. Export `HUGGING_FACE_HUB_TOKEN=` - -or with Docker: - -```shell -model=meta-llama/Llama-2-7b-chat-hf -volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run -token= - -docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model -``` - -### A note on Shared Memory (shm) - -[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by -`PyTorch` to do distributed training/inference. `text-generation-inference` make -use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models. - -In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if -peer-to-peer using NVLink or PCI is not possible. - -To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command. - -If you are running `text-generation-inference` inside `Kubernetes`. 
You can also add Shared Memory to the container by -creating a volume with: - -```yaml -- name: shm - emptyDir: - medium: Memory - sizeLimit: 1Gi -``` - -and mounting it to `/dev/shm`. - -Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that -this will impact performance. - -### Distributed Tracing - -`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature -by setting the address to an OTLP collector with the `--otlp-endpoint` argument. - -### Architecture - -![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png) - -### Local install - -You can also opt to install `text-generation-inference` locally. - -First [install Rust](https://rustup.rs/) and create a Python virtual environment with at least -Python 3.9, e.g. using `conda`: - -```shell -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - -conda create -n text-generation-inference python=3.11 -conda activate text-generation-inference -``` - -You may also need to install Protoc. - -On Linux: - -```shell -PROTOC_ZIP=protoc-21.12-linux-x86_64.zip -curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP -sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc -sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*' -rm -f $PROTOC_ZIP -``` - -On MacOS, using Homebrew: - -```shell -brew install protobuf -``` - -Then run: - -```shell -BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels -text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 -``` - -**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run: - -```shell -sudo apt-get install libssl-dev gcc -y -``` - -## Optimized architectures - -TGI works out of the box to serve optimized models for all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models). - -Other architectures are supported on a best-effort basis using: - -`AutoModelForCausalLM.from_pretrained(, device_map="auto")` - -or - -`AutoModelForSeq2SeqLM.from_pretrained(, device_map="auto")` - - - -## Run locally - -### Run - -```shell -text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 -``` - -### Quantization - -You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: - -```shell -text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize -``` - -4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. 
- -## Develop - -```shell -make server-dev -make router-dev -``` - -## Testing - -```shell -# python -make python-server-tests -make python-client-tests -# or both server and client tests -make python-tests -# rust cargo tests -make rust-tests -# integration tests -make integration-tests -``` diff --git a/router/client/Cargo.toml b/backends/client/Cargo.toml similarity index 100% rename from router/client/Cargo.toml rename to backends/client/Cargo.toml diff --git a/router/client/build.rs b/backends/client/build.rs similarity index 100% rename from router/client/build.rs rename to backends/client/build.rs diff --git a/router/client/src/lib.rs b/backends/client/src/lib.rs similarity index 100% rename from router/client/src/lib.rs rename to backends/client/src/lib.rs diff --git a/router/client/src/v2/client.rs b/backends/client/src/v2/client.rs similarity index 100% rename from router/client/src/v2/client.rs rename to backends/client/src/v2/client.rs diff --git a/router/client/src/v2/mod.rs b/backends/client/src/v2/mod.rs similarity index 100% rename from router/client/src/v2/mod.rs rename to backends/client/src/v2/mod.rs diff --git a/router/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs similarity index 100% rename from router/client/src/v2/sharded_client.rs rename to backends/client/src/v2/sharded_client.rs diff --git a/router/client/src/v3/client.rs b/backends/client/src/v3/client.rs similarity index 96% rename from router/client/src/v3/client.rs rename to backends/client/src/v3/client.rs index a996b14fae873c552458a9cc85cb78a0bfb6d541..d43f789e7ca95421266a068e267d65e000f309ea 100644 --- a/router/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -153,9 +153,13 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Most request will have that + add_special_tokens: true, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + cache_len: 0, + chunk_len: None, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, @@ -214,8 +218,13 @@ impl Client { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { - let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); let response = self.stub.prefill(request).await?.into_inner(); Ok(( response.generations, diff --git a/router/client/src/v3/mod.rs b/backends/client/src/v3/mod.rs similarity index 100% rename from router/client/src/v3/mod.rs rename to backends/client/src/v3/mod.rs diff --git a/router/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs similarity index 96% rename from router/client/src/v3/sharded_client.rs rename to backends/client/src/v3/sharded_client.rs index ae8a899b38a6b5571e4b61cba157deb73b5ab8c3..854a5895ebab7474b2391dead5078b19e3a05370 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -134,11 +134,12 @@ impl ShardedClient { pub async fn prefill( &mut self, batch: Batch, + cached_batch: Option, ) -> Result<(Vec, Option, PrefillTimings)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.prefill(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone(), 
cached_batch.clone()))) .collect(); #[allow(clippy::type_complexity)] let results: Result, Option, PrefillTimings)>> = @@ -221,6 +222,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, @@ -244,6 +246,8 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + cache_len: 0, + chunk_len: None, adapter_id: None, }; let batch = Batch { @@ -253,7 +257,7 @@ impl Health for ShardedClient { max_tokens: 2, max_blocks: 1, }; - self.clone().prefill(batch).await?; + self.clone().prefill(batch, None).await?; Ok(()) } } diff --git a/router/grpc-metadata/Cargo.toml b/backends/grpc-metadata/Cargo.toml similarity index 100% rename from router/grpc-metadata/Cargo.toml rename to backends/grpc-metadata/Cargo.toml diff --git a/router/grpc-metadata/src/lib.rs b/backends/grpc-metadata/src/lib.rs similarity index 100% rename from router/grpc-metadata/src/lib.rs rename to backends/grpc-metadata/src/lib.rs diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..831372cdf99650041bcc51d2e438ccc2c7893a4f --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,75 @@ +cmake_minimum_required(VERSION 3.20) + +if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug") + find_program(CCACHE_EXECUTABLE "ccache") + if (CCACHE_EXECUTABLE) + message(STATUS "Using ccache") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE) + endif () +endif () + +if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif () + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) +include(ExternalProject) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) +option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF) +set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support") +set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located") +set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located") +set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") + +# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features +find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) + +#### External dependencies #### +include(cmake/fmt.cmake) +include(cmake/json.cmake) +include(cmake/spdlog.cmake) +include(cmake/trtllm.cmake) + +# Let's build TRTLLM as part of CMake +add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..") + +# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so +set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE) + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h) +include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) 
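The backend is normally compiled through `backends/trtllm/build.rs` (shown further below), but the cache variables declared in this CMakeLists can also be exercised by hand; a standalone configure might look like this sketch, with the paths and the test flag as assumptions rather than a supported workflow:

```shell
# Illustrative standalone configure/build; the option names come from the
# CMakeLists above, the values are placeholders for a local TensorRT install.
cmake -S backends/trtllm -B build -G Ninja \
  -DCMAKE_BUILD_TYPE=Release \
  -DTGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST="89-real" \
  -DTGI_TRTLLM_BACKEND_TRT_ROOT=/usr/local/tensorrt \
  -DTGI_TRTLLM_BACKEND_BUILD_TESTS=ON
cmake --build build --parallel
```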
+target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include") +target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml) +target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt) + +# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back +install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker) +install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + # add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp) + # target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + # catch_discover_tests(tgi_trtllm_backend_tests) +endif () diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..97ef1a768917160d744ad92a46e88a7a54d492a4 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1" +async-stream = "0.3" +clap = { version = "4.5", features = ["derive"] } +cxx = "1.0" +hashbrown = "0.14" +hf-hub = { workspace = true } +log = { version = "0.4", features = [] } +text-generation-router = { path = "../../router" } +tokenizers = { workspace = true } +tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.15" +thiserror = "1.0.63" +tracing = "0.1" +tracing-opentelemetry = "0.25" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } + +[build-dependencies] +cmake = "0.1" +cxx-build = { version = "1.0", features = ["parallel"] } +pkg-config = "0.3" diff --git a/backends/trtllm/README.md b/backends/trtllm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..94064504d74869c38b70165dc8a3bbf85eafa26d --- /dev/null +++ b/backends/trtllm/README.md @@ -0,0 +1,46 @@ +# Text Generation Inference - TensorRT-LLM Backend Implementation + +## Description + +This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API + +## Simplified Request Sequence + +```mermaid +sequenceDiagram + actor User + participant TextGenerationInference.HttpServer + participant TextGenerationInference.TensorRtLlmBackend + participant TextGenerationInference.TensorRtLlmWorkerThread + participant TensorRtLlm.Executor + participant Nvidia.Gpu + User ->> TextGenerationInference.HttpServer: POST /generate + TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters + TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request + 
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher + activate Nvidia.Gpu + TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier + rect rgb(10, 92, 54) + loop every 100us + rect rgb(15, 81, 50) + alt Acquire lock to query executor + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated + else There are new generated tokens + TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens + TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted) + rect rgb(11, 110, 79) + alt Generated token is final + TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection + else Generated token is not final + TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded + end + end + end + end + deactivate Nvidia.Gpu + end + end + +``` diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..985019260b837035907b7e8af5eb20688e85301e --- /dev/null +++ b/backends/trtllm/build.rs @@ -0,0 +1,160 @@ +use cxx_build::CFG; +use pkg_config; +use std::env; +use std::env::consts::ARCH; +use std::path::{absolute, PathBuf}; + +const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; +const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); +const CUDA_REQUIRED_VERSION: &str = "12.6"; +const MPI_REQUIRED_VERSION: &str = "4.1"; +const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); +const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); +const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR"); + +// Dependencies +const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"]; +const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"]; +const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [ + ("dylib", "tensorrt_llm"), + ("static", "tensorrt_llm_executor_static"), + ("dylib", "tensorrt_llm_nvrtc_wrapper"), + ("dylib", "nvinfer_plugin_tensorrt_llm"), + ("dylib", "decoder_attention"), +]; + +macro_rules! 
probe { + ($name: expr, $version: expr) => { + if let Err(_) = pkg_config::probe_library($name) { + pkg_config::probe_library(&format!("{}-{}", $name, $version)) + .expect(&format!("Failed to locate {}", $name)); + } + }; +} + +fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) { + // Build the backend implementation through CMake + let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); + let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); + let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real"); + + let mut install_path = PathBuf::from(install_path); + if !install_path.is_absolute() { + install_path = absolute(out_dir).expect("cannot happen").join(install_path); + } + + let _ = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .profile(match is_debug { + true => "Debug", + false => "Release", + }) + .env("OPT_LEVEL", opt_level) + .define("CMAKE_INSTALL_PREFIX", &install_path) + .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc") + .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list) + .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path) + .build(); + + // Additional transitive CMake dependencies + let deps_folder = out_dir.join("build").join("_deps"); + for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES { + let dep_name = match is_debug { + true => format!("{}d", dependency), + false => String::from(dependency), + }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); + println!("cargo:rustc-link-lib=static={}", dep_name); + } + + // Emit linkage information from the artifacts we just built + let install_lib_path = install_path.join("lib"); + + println!( + r"cargo:warning=Adding link search path: {}", + install_lib_path.display() + ); + println!(r"cargo:rustc-link-search={}", install_lib_path.display()); + + (PathBuf::from(install_path), deps_folder) +} + +fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) { + let ndebug = match is_debug { + true => "1", + false => "0", + }; + + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .static_flag(true) + .include(deps_folder.join("fmt-src").join("include")) + .include(deps_folder.join("spdlog-src").join("include")) + .include(deps_folder.join("json-src").join("include")) + .include(deps_folder.join("trtllm-src").join("cpp").join("include")) + .include("/usr/local/cuda/include") + .include("/usr/local/tensorrt/include") + .file("src/ffi.cpp") + .std("c++20") + .define("NDEBUG", ndebug) + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=CMakeLists.txt"); + println!("cargo:rerun-if-changed=cmake/trtllm.cmake"); + println!("cargo:rerun-if-changed=cmake/json.cmake"); + println!("cargo:rerun-if-changed=cmake/fmt.cmake"); + println!("cargo:rerun-if-changed=cmake/spdlog.cmake"); + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + println!("cargo:rerun-if-changed=include/ffi.h"); + println!("cargo:rerun-if-changed=src/ffi.cpp"); +} + +fn main() { + // Misc variables + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let build_profile = env::var("PROFILE").unwrap(); + let (is_debug, opt_level) = match build_profile.as_ref() { + "debug" => (true, "0"), + _ => (false, "3"), + }; + + // Build the backend + let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); + + // Build the FFI layer calling 
the backend above + build_ffi_layer(&deps_folder, is_debug); + + // Emit linkage search path + probe!("ompi", MPI_REQUIRED_VERSION); + + // Probe CUDA & co. with pkg-config + CUDA_TRANSITIVE_DEPS.iter().for_each(|name| { + probe!(name, CUDA_REQUIRED_VERSION); + }); + + // NCCL is slightly trickier because it might not have a pkgconfig installed + let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH); + let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default); + println!(r"cargo:rustc-link-search=native={}", nccl_library_path); + println!("cargo:rustc-link-lib=dylib=nccl"); + + // TensorRT + let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib"); + println!(r"cargo:rustc-link-search=native={}", tensort_library_path); + println!("cargo:rustc-link-lib=dylib=nvinfer"); + + // TensorRT-LLM + TENSORRT_LLM_TRANSITIVE_DEPS + .iter() + .for_each(|(link_type, name)| { + println!("cargo:rustc-link-lib={}={}", link_type, name); + }); + + // Backend + BACKEND_DEPS.iter().for_each(|name| { + println!("cargo:rustc-link-lib=static={}", name); + }); +} diff --git a/backends/trtllm/cmake/fmt.cmake b/backends/trtllm/cmake/fmt.cmake new file mode 100644 index 0000000000000000000000000000000000000000..afd6ea5f0906e31a491d2113d6527cd9d6a9850e --- /dev/null +++ b/backends/trtllm/cmake/fmt.cmake @@ -0,0 +1,6 @@ +FetchContent_Declare( + fmt + DOWNLOAD_EXTRACT_TIMESTAMP + URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz +) +FetchContent_MakeAvailable(fmt) diff --git a/backends/trtllm/cmake/json.cmake b/backends/trtllm/cmake/json.cmake new file mode 100644 index 0000000000000000000000000000000000000000..67eff2fe606708a4561b2f9ad3f19441c561a92d --- /dev/null +++ b/backends/trtllm/cmake/json.cmake @@ -0,0 +1,6 @@ +fetchcontent_declare( + json + DOWNLOAD_EXTRACT_TIMESTAMP + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz +) +fetchcontent_makeavailable(json) diff --git a/backends/trtllm/cmake/spdlog.cmake b/backends/trtllm/cmake/spdlog.cmake new file mode 100644 index 0000000000000000000000000000000000000000..7f529a7d29e5ab39d9af9b233399f9cc3da1d9da --- /dev/null +++ b/backends/trtllm/cmake/spdlog.cmake @@ -0,0 +1,17 @@ +set(SPDLOG_USE_FMT ON) +set(SPDLOG_BUILD_SHARED OFF) +set(SPDLOG_FMT_EXTERNAL ON) + +# Define the level at which SPDLOG_ compilation level is defined +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG) +else () + add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO) +endif () + +fetchcontent_declare( + spdlog + DOWNLOAD_EXTRACT_TIMESTAMP + URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz +) +fetchcontent_makeavailable(spdlog) diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake new file mode 100644 index 0000000000000000000000000000000000000000..5f1b6c19c01f25e31a5703d75f6f8774b1a7250b --- /dev/null +++ b/backends/trtllm/cmake/trtllm.cmake @@ -0,0 +1,43 @@ +set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR}) +set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR}) + +set(USE_CXX11_ABI ON) +set(BUILD_PYT OFF) +set(BUILD_PYBIND OFF) +set(BUILD_MICRO_BENCHMARKS OFF) +set(BUILD_BENCHMARKS OFF) +set(BUILD_TESTS OFF) +set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST}) + +message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}") + +if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + set(FAST_BUILD ON) + set(NVTX_DISABLE OFF) +else () + 
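+    # Non-debug builds: turn off the FAST_BUILD shortcut, enable FAST_MATH and disable NVTX instrumentation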
set(FAST_BUILD OFF) + set(FAST_MATH ON) + set(NVTX_DISABLE ON) +endif () + +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc + GIT_SHALLOW FALSE + DOWNLOAD_EXTRACT_TIMESTAMP +) +fetchcontent_makeavailable(trtllm) + +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") +execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") +execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/") + +# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here +set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name") +set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}" + CACHE INTERNAL "nvrtc wrapper library path") + +# The same Executor Static library +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name") +set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path") diff --git a/server/marlin/marlin_kernels/py.typed b/backends/trtllm/cmake/utils/detect_cuda_arch.cu similarity index 100% rename from server/marlin/marlin_kernels/py.typed rename to backends/trtllm/cmake/utils/detect_cuda_arch.cu diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 0000000000000000000000000000000000000000..d23f6288964d74a6c71e11bc13b3b235a4d592fa --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,144 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +using json = nlohmann::json; +namespace tle = tensorrt_llm::executor; + + +#define CAST_SIZETYPE(x) static_cast(x) + +namespace huggingface::tgi::backends { + using RequestId = tle::IdType; + using TokenId = tle::TokenIdType; + + const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false); + constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING( + "Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})"); + constexpr auto FMT_EXECUTOR_STATS = FMT_STRING( + "Submitting inference [{}] to the executor ({:d} already in-flight)"); + constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING( + "Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}"); + + /** + * Initialize all the components required by TRTLLM. 
+ * It is required to call this function before attempting to load any engine + */ + void InitializeBackend(); + + /** + * Initialize logging mechanism + */ + void InitializeLogging(); + + + /** + * + * @param config TensorRT-LLM configuration object + * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode + * @return + */ + tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath); + + /** + * + * @param worldSize + * @param workerPath + * @return + */ + tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept; + + /** + * Get the sampling configuration from the parameters provided by TGI + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + tle::SamplingConfig GetSamplingConfig( + uint32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed + ) noexcept; + + /** + * Attempt to retrieve the + * @param generationConfigPath + * @return + */ + std::optional>> + GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept; + + /** + * + */ + class TensorRtLlmBackend { + private: + const json config; + tle::Executor executor; + + /** Frequently accessed variables cached here **/ + uint32_t maxNumTokens; + std::list> stopWords; + + public: + explicit TensorRtLlmBackend( + const std::filesystem::path &engineFolder, + const std::filesystem::path &executorWorker + ); + + /** + * Query the executor for the number of token available for pulling + * @return + */ + [[nodiscard]] size_t NumResponsesReady() const; + + /** + * Submit a new generation task to the executor + * @param tokens + * @param topK + * @param topP + * @param temperature + * @param repetitionPenalty + * @param frequencyPenalty + * @param seed + * @return Request id related to this generation for reference + */ + [[nodiscard]] RequestId Submit( + const std::vector &tokens, + uint32_t maxNewTokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetitionPenalty, + float_t frequencyPenalty, + uint64_t seed + ); + + [[nodiscard]] std::vector PullNewTokens(); + }; +} + + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h new file mode 100644 index 0000000000000000000000000000000000000000..449bcd4d7398146867628825dfe182419b5666e6 --- /dev/null +++ b/backends/trtllm/include/ffi.h @@ -0,0 +1,75 @@ +// +// Created by mfuntowicz on 7/11/24. 
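+// Declares the cxx-facing wrapper (TensorRtLlmBackendImpl) bridging the Rust code in src/lib.rs with the C++ backend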
+// + +#ifndef TGI_TRTLLM_BACKEND_FFI_H +#define TGI_TRTLLM_BACKEND_FFI_H + +#include +#include +#include +#include "backend.h" + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl; +} + +// Template to support returning error from TllmException back to Rust in a Result<> +#include + +namespace rust::behavior { + template + static void trycatch(Try &&func, Fail &&fail) noexcept try { + func(); + } catch (tensorrt_llm::common::TllmException &e) { + fail(e.what()); + } +} + +#include "backends/trtllm/src/lib.rs.h" + +namespace huggingface::tgi::backends { + + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { + public: + /*** + * + * @param engineFolder + * @param executorWorker + */ + TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); + + /*** + * + * @param tokens + * @param maxNewTokens + * @param topK + * @param topP + * @param temperature + * @param repetition_penalty + * @param frequency_penalty + * @param seed + * @return + */ + [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] + uint64_t + Submit(rust::Slice tokens, uint32_t maxNewTokens, + int32_t topK, float_t topP, float_t temperature, + float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); + + /*** + * + * @return + */ + std::unique_ptr> PullTokens(); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); +} + +#endif //TGI_TRTLLM_BACKEND_FFI_H diff --git a/backends/trtllm/include/hardware.h b/backends/trtllm/include/hardware.h new file mode 100644 index 0000000000000000000000000000000000000000..9633495f4fd54682da6508936b9cb853e221411d --- /dev/null +++ b/backends/trtllm/include/hardware.h @@ -0,0 +1,59 @@ +// +// Created by mfuntowicz on 7/23/24. +// + +#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H +#define TGI_TRTLLM_BACKEND_HARDWARE_H + +#include +#include +#include +#include +#include + +namespace huggingface::hardware::cuda { + +#define AMPERE_SM_MAJOR 8 +#define HOPPER_SM_MAJOR 9 + + /** + * Store information about the version of the CUDA Compute Capabilities detected on the device + */ + struct CudaComputeCapabilities { + int32_t major; + int32_t minor; + + [[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; } + + [[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; } + }; + + CudaComputeCapabilities GetCudaComputeCapabilities() { + // Get the compute capabilities of the current hardware + nvmlDevice_t device; + CudaComputeCapabilities capabilities{0, 0}; + if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) { + SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0"); + if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) { + SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor); + } + } + + return capabilities; + } + + /** + * Return the number of GPU detected. 
If no GPU is detected, return size_t::max() + * @return + */ + std::optional GetNumDevices() { + uint32_t numGpus = 0; + if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) { + return std::optional(numGpus); + } else { + return std::nullopt; + } + } +} + +#endif //TGI_TRTLLM_BACKEND_HARDWARE_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4dd41de0072866d818a3c901d813f0344a6af683 --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,203 @@ +#include +#include + +#include +#include +#include + +#include "backend.h" +#include "hardware.h" + + +void huggingface::tgi::backends::InitializeLogging() { +#ifdef NDEBUG + if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) { + std::string log_level(TRTLLM_LOG_LEVEL_CSTR); + std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) { + return std::tolower(c); + }); + + if (log_level == "debug") + spdlog::set_level(spdlog::level::debug); + else + spdlog::set_level(spdlog::level::info); + } +#else + spdlog::set_level(spdlog::level::debug); +#endif +} + +void huggingface::tgi::backends::InitializeBackend() { + SPDLOG_INFO("Initializing Backend..."); + nvmlInit_v2(); + initTrtLlmPlugins(); + + InitializeLogging(); + + SPDLOG_INFO("Backend Executor Version: {}", tle::version()); + const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); + if (numGpus.has_value()) { + SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); + } else { + SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system"); + } +} + +[[nodiscard]] +tle::ParallelConfig +huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept { + auto mode = tle::CommunicationMode::kLEADER; + std::optional orchestratorConfig = std::nullopt; + + if (worldSize > 1) { + SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode"); + mode = tle::CommunicationMode::kORCHESTRATOR; + orchestratorConfig = std::make_optional(true, workerPath, nullptr, true); + } else { + SPDLOG_INFO("Detected single engine deployment, using leader mode"); + } + + return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig); +} + +[[nodiscard]] +tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { + tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1); + + // Retrieve the compute capabilities to enable some options at runtime + const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); + + // Single engine (TP = PP = 1) -> using leader mode (no MPI involved) + const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get(); + execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath)); + + // Define some configuration variables + execConfig.setKvCacheConfig(tle::KvCacheConfig(true)); + execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere()); + execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)); + return execConfig; +} + +tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( + const uint32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetition_penalty, + const float_t frequency_penalty, + const uint64_t seed) noexcept { + + return tle::SamplingConfig( + 1, // TGI only use a single beam + topK, + 
topP, + std::nullopt, + std::nullopt, + std::nullopt, + seed, + temperature, + temperature, + std::nullopt, + repetition_penalty, + std::nullopt, + frequency_penalty + ); +} + +std::optional>> +huggingface::tgi::backends::GetStopWordsFromConfig( + const std::filesystem::path &generationConfigPath) noexcept { + if (exists(generationConfigPath)) { + const auto generationConfig = json::parse(std::ifstream(generationConfigPath)); + if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) { + SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size()); + std::list> stopWords(eosTokenIds.size()); + + const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type { + return {tokenIdObj.template get()}; + }; + + std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token); + return stopWords; + } else { + SPDLOG_INFO("Invalid EOS tokens entry found (not an array)"); + } + } else { + SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist"); + } + + return std::nullopt; +} + +huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend( + const std::filesystem::path &enginesFolder, + const std::filesystem::path &executorWorker +) : + config(json::parse(std::ifstream(enginesFolder / "config.json"))), + executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY, + GetExecutorConfig(config, executorWorker.string())) { + + SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get()); + + // Ensure we have enough GPUs on the system + const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get(); + const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0); + if (numGpus < worldSize) { + SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize); + // todo : raise exception to catch on rust side + } + + // Cache variables + maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get(); + + // Attempt to discover stopWords from the generation_config.json + const auto generationConfigPath = enginesFolder / "generation_config.json"; + stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list>()); +} + +[[nodiscard("Returned number of requests needs to be consumed")]] +size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { +#ifdef NDEBUG + return executor.getNumResponsesReady(); +#else + const auto numResponses = executor.getNumResponsesReady(); + if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses); + return numResponses; +#endif +} + +[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]] +tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( + const std::vector &tokens, + const uint32_t maxNewTokens, + const int32_t topK, + const float_t topP, + const float_t temperature, + const float_t repetitionPenalty, + const float_t frequencyPenalty, + const uint64_t seed +) { + const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast(maxNumTokens - tokens.size())); +#ifndef NDEBUG + { + const auto &iterations = executor.getLatestIterationStats(); + const auto &lastIteration = iterations.front(); + + SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests); + SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed); + SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), 
maxNewTokensChecked); + } +#endif + + const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed); + + // Build the request + auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG}; + request.setStopWords(stopWords); + + // Submit to the executor for batching + return executor.enqueueRequest(request); +} + +std::vector huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() { + return executor.awaitResponses(); +} diff --git a/backends/trtllm/scripts/install_tensorrt.sh b/backends/trtllm/scripts/install_tensorrt.sh new file mode 100755 index 0000000000000000000000000000000000000000..4c2dc26b6bfbdf900edfbc178467b4ffddfd3bae --- /dev/null +++ b/backends/trtllm/scripts/install_tensorrt.sh @@ -0,0 +1,113 @@ +#!/bin/bash + +set -ex + +TRT_VER_BASE="10.4.0" +TRT_VER_FULL="${TRT_VER_BASE}.26" +CUDA_VER="12.6" +CUDNN_VER="9.5.0.50-1" +NCCL_VER="2.22.3-1+cuda12.6" +CUBLAS_VER="12.6.3.3-1" +NVRTC_VER="12.6.77-1" + +for i in "$@"; do + case $i in + --TRT_VER=?*) TRT_VER="${i#*=}";; + --CUDA_VER=?*) CUDA_VER="${i#*=}";; + --CUDNN_VER=?*) CUDNN_VER="${i#*=}";; + --NCCL_VER=?*) NCCL_VER="${i#*=}";; + --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";; + *) ;; + esac + shift +done + +NVCC_VERSION_OUTPUT=$(nvcc --version) +if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then + echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}." + exit 1 +fi + +install_ubuntu_requirements() { + apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates + ARCH=$(uname -m) + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi + curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb + dpkg -i cuda-keyring_1.1-1_all.deb + rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list + + apt-get update + if [[ $(apt list --installed | grep libcudnn9) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcudnn9* + fi + if [[ $(apt list --installed | grep libnccl) ]]; then + apt-get remove --purge -y --allow-change-held-packages libnccl* + fi + if [[ $(apt list --installed | grep libcublas) ]]; then + apt-get remove --purge -y --allow-change-held-packages libcublas* + fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER} + apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} + apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. 
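+    # Convert the dotted CUDA version (e.g. 12.6) into the dashed form (12-6) used by the apt package names below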
+ NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} + apt-get clean + rm -rf /var/lib/apt/lists/* +} + +install_centos_requirements() { + CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + yum -y update + yum -y install epel-release + yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER} + yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} + yum clean all +} + +install_tensorrt() { + #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') + #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") + TRT_CUDA_VERSION="12.6" + + if [ -z "$RELEASE_URL_TRT" ];then + ARCH=${TRT_TARGETARCH} + if [ -z "$ARCH" ];then ARCH=$(uname -m);fi + if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi + if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi + if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi + if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz + fi + wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + tar -xf /tmp/TensorRT.tar -C /usr/local/ + mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt + # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl + rm -rf /tmp/TensorRT.tar +} + +# Install base packages depending on the base OS +ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +case "$ID" in + debian) + install_ubuntu_requirements + install_tensorrt + ;; + ubuntu) + install_ubuntu_requirements + install_tensorrt + ;; + centos) + install_centos_requirements + install_tensorrt + ;; + *) + echo "Unable to determine OS..." + exit 1 + ;; +esac diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs new file mode 100644 index 0000000000000000000000000000000000000000..812fd6e30d8b2ca2b0fad631c4fdf94d23e1b6b8 --- /dev/null +++ b/backends/trtllm/src/errors.rs @@ -0,0 +1,22 @@ +use std::path::PathBuf; +use thiserror::Error; + +use text_generation_router::server; + +#[derive(Debug, Error)] +pub enum TensorRtLlmBackendError { + #[error("Provided engine folder {0} doesn't exist")] + EngineFolderDoesntExists(PathBuf), + #[error("Provided executorWorker binary path {0} doesn't exist")] + ExecutorWorkerNotFound(PathBuf), + #[error("TensorRT-LLM Runtime error: {0}")] + Runtime(String), + #[error("Tokenizer error: {0}")] + Tokenizer(String), + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a92c050f653f8c294f76795bb193d7428bd2633 --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,89 @@ +// +// Created by mfuntowicz on 6/30/24. 
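+// Implements the FFI entry points declared in include/ffi.h and exposed to Rust through the cxx bridge in src/lib.rs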
+// +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "backends/trtllm/include/ffi.h" + + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl( + const std::string_view &engineFolder, + const std::string_view &executorWorker +) : TensorRtLlmBackend(engineFolder, executorWorker) {} + + +uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( + rust::Slice tokens, + uint32_t maxNewTokens, + int32_t topK, + float_t topP, + float_t temperature, + float_t repetition_penalty, + float_t frequency_penalty, + uint64_t seed) { + + // This will copy all the items from the initial slice + std::vector tokens_(tokens.begin(), tokens.end()); + return TensorRtLlmBackend::Submit( + std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed); +} + +std::unique_ptr> +huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() { + const auto responses = TensorRtLlmBackend::PullNewTokens(); + + auto steps = std::make_unique>(); + steps->reserve(responses.size()); + +#ifndef NDEBUG + SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size()); +#endif + + // Transform tle::Response to GenerationStep + std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) { + const auto reqId = r.getRequestId(); + if (!r.hasError()) { + const auto result = r.getResult(); + return GenerationStep{ + reqId, + static_cast(result.outputTokenIds[0][0]), + result.logProbs.value()[0][0], + result.isFinal, + false, + std::string() + }; + } else { + return GenerationStep{ + reqId, + 0, + 0.0, + true, + true, + std::move(r.getErrorMsg()) + }; + } + }); + + return steps; +} + +std::unique_ptr +huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { + SPDLOG_INFO("Creating TensorRT-LLM Backend"); + // Unconditionally call this to initialize and discover TRTLLM plugins + InitializeBackend(); + + const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); + const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); + return std::make_unique(std::move(enginePath), std::move(executorPath)); +} diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..edd8caff154e94269ff921ef93cb4e721131a9ea --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,68 @@ +pub use looper::TensorRtLlmBackendV2; + +pub mod errors; +mod looper; +mod utils; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + /// Struct used as shared type between rust and C++ to represent the result + /// of a single decoding iteration + #[derive(Debug, Clone)] + pub struct GenerationStep { + request_id: u64, + token_id: u32, + log_prob: f32, + is_final: bool, + has_error: bool, + error_msg: String, + } + + unsafe extern "C++" { + include!("backends/trtllm/src/ffi.cpp"); + + /// Represent an instance of the underlying TensorRT-LLM backend + type TensorRtLlmBackendImpl; + + /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend + /// + /// # Arguments + /// + /// * `engine_folder`: Path to the folder containing all the TRTLLM engines + /// * `executor_worker`: Path to the TRTLLM executor worker + /// + /// returns: + /// + /// # Examples + /// + /// ``` + /// + /// ``` + #[rust_name = "create_tensorrt_llm_backend"] + fn 
CreateTensorRtLlmBackend( + engine_folder: &str, + executor_worker: &str, + ) -> Result>; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; + + #[rust_name = "submit"] + fn Submit( + self: Pin<&mut TensorRtLlmBackendImpl>, + tokens: &[u32], + max_new_tokens: u32, + top_k: i32, + top_p: f32, + temperature: f32, + repetition_penalty: f32, + frequency_penalty: f32, + seed: u64, + ) -> Result; + + #[rust_name = "pull_tokens"] + fn PullTokens( + self: Pin<&mut TensorRtLlmBackendImpl>, + ) -> Result>>; + } +} diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs new file mode 100644 index 0000000000000000000000000000000000000000..e26155c163c910f46349adbe82838345d37289ce --- /dev/null +++ b/backends/trtllm/src/looper.rs @@ -0,0 +1,382 @@ +use std::hint; +use std::ops::Deref; +use std::path::Path; + +use async_trait::async_trait; +use cxx::UniquePtr; +use hashbrown::HashMap; +use tokenizers::Tokenizer; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::TryAcquireError; +use tokio::task::{spawn_blocking, JoinHandle}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{debug, error, warn}; + +use text_generation_router::infer::InferError::{GenerationError, ValidationError}; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidationError::{ + EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, +}; +use text_generation_router::validation::{Chunk, ValidGenerateRequest}; +use text_generation_router::{FinishReason, Token}; + +use crate::errors::TensorRtLlmBackendError; +use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl}; +use crate::utils::first_line; + +type InferResult = Result; + +/// Wrap the requests along with the channel used to stream back to the client the decoded tokens +struct GenerationContext { + request: ValidGenerateRequest, + start: Option, + queued: Instant, + streamer: UnboundedSender>, +} + +#[derive(Debug, Copy, Clone)] +struct DecodedToken { + id: u32, + log_prob: f32, + is_final: bool, +} + +impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { + type Error = InferError; + + fn try_from(step: &'step GenerationStep) -> Result { + if !step.has_error { + Ok(Self { + id: step.token_id, + log_prob: step.log_prob, + is_final: step.is_final, + }) + } else { + Err(GenerationError(step.error_msg.clone())) + } + } +} + +/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens +struct DecodedTokenContext { + token: DecodedToken, + start: Option, + queued: Instant, + channel: UnboundedSender>, +} + +fn executor_status_looper( + mut backend: UniquePtr, + max_inflight_requests: usize, + mut waiting_requests: UnboundedReceiver, + post_processor_sender: UnboundedSender<(u64, InferResult)>, +) { + // Track the tuple (request_id, stream) for each request + let mut in_flights = + HashMap::::with_capacity(max_inflight_requests * 2); + + // TODO: Does it need a spin-loop? + 'scheduler: loop { + // Is there any request pending to be scheduled? 
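+        // Drain at most the number of requests currently queued, so the loop keeps alternating
+        // between submitting new work and pulling freshly generated tokens further below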
+ let awaiting_requests = waiting_requests.len(); + for _ in 0..awaiting_requests { + // Retrieve all the requests + if let Some(mut ctx) = waiting_requests.blocking_recv() { + // Submit all the request to the executor and move the context to the in-flight tracker + let request = &ctx.request; + let generation_params = &request.parameters; + let stopping_params = &request.stopping_parameters; + let input_ids = request.input_ids.as_deref(); + + // Submit to the TensorRT-LLM executor for scheduling + match backend.pin_mut().submit( + &input_ids.unwrap(), // This is checked beforehand in validate() + stopping_params.max_new_tokens, + generation_params.top_k as i32, + generation_params.top_p, + generation_params.temperature, + generation_params.repetition_penalty, + generation_params.frequency_penalty, + generation_params.seed, + ) { + Ok(request_id) => { + // Insert the context linked to the generated request id in the tracker + debug!("[in-flight] Added {}", request_id); + ctx.start = Some(Instant::now()); + in_flights.insert(request_id, ctx); + } + Err(e) => { + // Return to the caller + let what = e.to_string(); + error!(error = what.as_str(), "Failed to schedule request"); + + let err = Err(InferError::Overloaded(TryAcquireError::NoPermits)); + if let Err(_) = ctx.streamer.send(err) { + error!("Failed to send back error to the client"); + } + } + }; + } + } + + if backend.num_responses_ready() > 0 { + match backend.pin_mut().pull_tokens() { + Ok(responses) => { + // Iterate through all the decoded token + for step in responses.deref() { + if let Some(ctx) = in_flights.get(&step.request_id) { + // Remove from tracked requests + let parcel = + DecodedToken::try_from(step).map(|dt| DecodedTokenContext { + token: dt, + start: ctx.start, + queued: ctx.queued, + channel: ctx.streamer.clone(), + }); + + // Submit the work to p:the post_processor + let posted = post_processor_sender.send((step.request_id, parcel)); + + if posted.is_err() || step.is_final { + debug!("Removing {}", step.request_id); + let _ = in_flights.remove(&step.request_id); + } + } else { + warn!("Untracked request {}", step.request_id,); + } + } + } + Err(ref err) => { + error!("Failed to get responses from the executor: {}.", err.what()); + break 'scheduler; + } + } + } + + // Hint the CPU we are spin-locking + hint::spin_loop(); + } +} + +fn post_processor_looper( + tokenizer: Tokenizer, + max_inflight_requests: usize, + mut decoded_tokens: UnboundedReceiver<(u64, InferResult)>, +) { + let mut states: HashMap> = HashMap::with_capacity(max_inflight_requests * 2); + + 'post_processor: loop { + if decoded_tokens.is_closed() { + warn!("Post processor IPC is closed, loop will exit now."); + break 'post_processor; + } + + if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() { + match decoded { + Ok(ctx) => { + states + .entry(request_id) + .and_modify(|s| s.push(*&ctx.token.id)) + .or_insert_with(|| { + let mut state = Vec::with_capacity(MAX_NUM_TOKENS); + state.push(*&ctx.token.id); + state + }); + + let out = match tokenizer.decode(&[ctx.token.id], false) { + Ok(text) => { + let is_special = + tokenizer.get_added_vocabulary().is_special_token(&text); + let token = Token { + id: ctx.token.id, + text, + logprob: ctx.token.log_prob, + special: is_special, + }; + + let out = if !ctx.token.is_final { + InferStreamResponse::Intermediate { + token, + top_tokens: vec![], + } + } else { + let tokens = states.remove(&request_id).unwrap(); + let text = tokenizer.decode(&tokens, true); + let generated_text = GeneratedText { 
+ text: text.unwrap(), + generated_tokens: tokens.len() as u32, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }; + + InferStreamResponse::End { + token, + top_tokens: vec![], + generated_text, + start: ctx.start.unwrap(), + queued: ctx.queued, + } + }; + + Ok(out) + } + Err(err) => Err(GenerationError(err.to_string())), + }; + + if let Err(_) = ctx.channel.send(out) { + warn!("Failed to send decoded token back to the user") + } + } + Err(_err) => { + todo!("what do we do?") + } + } + } + } +} + +fn ensure_paths_exist, PP: AsRef>( + engine_folder: P, + executor_worker_path: PP, +) -> Result<(String, String), TensorRtLlmBackendError> { + // Retrieve paths as &str for the backend creation + let engine_folder = engine_folder.as_ref(); + let executor_worker_path = executor_worker_path.as_ref(); + + // Ensure the engine folder exists + if !engine_folder.exists() { + let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + + // Ensure executor worker binary exists + if !executor_worker_path.exists() { + let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + + let engine_folder = String::from( + engine_folder + .to_str() + .expect("Failed to convert engine_folder to valid UTF-8"), + ); + + let executor_worker_path = String::from( + executor_worker_path + .to_str() + .expect("Failed to convert executor_worker_path to valid UTF-8"), + ); + + Ok((engine_folder, executor_worker_path)) +} + +unsafe impl Send for TensorRtLlmBackendImpl {} + +pub struct TensorRtLlmBackendV2 { + executor_looper: JoinHandle<()>, + post_processor_looper: JoinHandle<()>, + executor: UnboundedSender, +} + +impl TensorRtLlmBackendV2 { + pub fn new + Send, PP: AsRef + Send>( + tokenizer: Tokenizer, + engine_folder: P, + executor_worker_path: PP, + max_inflight_requests: usize, + ) -> Result { + let (engine_folder, executor_worker_path) = + ensure_paths_exist(engine_folder, executor_worker_path)?; + + // Allocate the IPC layer to communicate with the backend + let (executor_sender, executor_receiver) = unbounded_channel(); + let (post_processor_sender, post_processor_receiver) = unbounded_channel(); + + // Create the FFI backend + let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path) + .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; + + // Executor looper is responsible for scheduling and pulling requests state at regular interval + let executor_looper = spawn_blocking(move || { + executor_status_looper( + backend, + max_inflight_requests, + executor_receiver, + post_processor_sender, + ) + }); + + // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user + let post_processor_looper = spawn_blocking(move || { + post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver) + }); + + Ok(TensorRtLlmBackendV2 { + executor_looper, + post_processor_looper, + executor: executor_sender, + }) + } + + fn validate(request: &ValidGenerateRequest) -> InferResult<()> { + if request.input_ids.is_none() { + return Err(ValidationError(UnsupportedModality("No token provided"))); + } + + if request.top_n_tokens > 1 { + return Err(ValidationError(TopNTokensDisabled)); + } + + // TODO: Is it really needed? How can it be validated before? 
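+        // Grammar-constrained generation is not supported by this backend, so reject it upfront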
+ if request.parameters.grammar.is_some() { + return Err(ValidationError(Grammar)); + } + + match request.inputs.len() { + 0 => Err(ValidationError(EmptyInput)), + 2.. => Err(GenerationError( + "TensorRT-LLM backend don't support multi-chunk".into(), + )), + 1 => match request.inputs.first().expect("Single item-chunk") { + Chunk::Text(_) => Ok(()), + Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))), + }, + } + } +} + +#[async_trait] +impl Backend for TensorRtLlmBackendV2 { + fn schedule( + &self, + inner: ValidGenerateRequest, + ) -> Result>, InferError> { + Self::validate(&inner)?; + + // Open-up the stream to send tokens + let (streamer, receiver) = unbounded_channel::>(); + + // Send the context to the executor for scheduling + let queued = Instant::now(); + match self.executor.send(GenerationContext { + request: inner, + start: None, + queued, + streamer, + }) { + Ok(_) => Ok(UnboundedReceiverStream::new(receiver)), + Err(_) => Err(GenerationError( + "Failed to submit request to the backend".into(), + )), + } + } + + async fn health(&self, _: bool) -> bool { + !self.executor_looper.is_finished() & !self.post_processor_looper.is_finished() + } +} diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..6a247fc1d5265b8b7b240e13169948e9be34eb0b --- /dev/null +++ b/backends/trtllm/src/main.rs @@ -0,0 +1,302 @@ +use std::path::{Path, PathBuf}; + +use clap::Parser; +use hf_hub::api::tokio::{Api, ApiBuilder}; +use hf_hub::{Cache, Repo, RepoType}; +use tokenizers::Tokenizer; +use tracing::info; + +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackendV2; +use text_generation_router::server::get_base_tokenizer; +use text_generation_router::usage_stats::UsageStatsLevel; +use text_generation_router::{server, HubTokenizerConfig}; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(long, env, required = true)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env)] + model_id: String, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(long, env)] + auth_token: Option, + #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")] + executor_worker: PathBuf, + #[clap(default_value = "on", long, env)] + usage_stats: 
usage_stats::UsageStatsLevel, +} + +async fn get_tokenizer( + tokenizer_name: &str, + tokenizer_config_path: Option<&str>, + revision: Option<&str>, +) -> Option { + // Parse Huggingface hub token + let authorization_token = std::env::var("HF_TOKEN") + .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")) + .ok(); + + // Tokenizer instance + let local_path = Path::new(tokenizer_name); + + // Shared API builder initialization + let api_builder = || { + let mut builder = ApiBuilder::new() + .with_progress(false) + .with_token(authorization_token); + + if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") { + builder = builder.with_cache_dir(cache_dir.into()); + } + + builder + }; + + // Decide if we need to use the API based on the revision and local path + let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir(); + + // Initialize API if needed + #[derive(Clone)] + enum Type { + Api(Api), + Cache(Cache), + None, + } + let api = if use_api { + if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) { + let cache = std::env::var("HUGGINGFACE_HUB_CACHE") + .map_err(|_| ()) + .map(|cache_dir| Cache::new(cache_dir.into())) + .unwrap_or_else(|_| Cache::default()); + tracing::warn!("Offline mode active using cache defaults"); + Type::Cache(cache) + } else { + tracing::info!("Using the Hugging Face API"); + match api_builder().build() { + Ok(api) => Type::Api(api), + Err(_) => { + tracing::warn!("Unable to build the Hugging Face API"); + Type::None + } + } + } + } else { + Type::None + }; + + // Load tokenizer and model info + let ( + tokenizer_filename, + _config_filename, + tokenizer_config_filename, + _preprocessor_config_filename, + _processor_config_filename, + ) = match api { + Type::None => ( + Some(local_path.join("tokenizer.json")), + Some(local_path.join("config.json")), + Some(local_path.join("tokenizer_config.json")), + Some(local_path.join("preprocessor_config.json")), + Some(local_path.join("processor_config.json")), + ), + Type::Api(api) => { + let api_repo = api.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.unwrap_or_else(|| "main").to_string(), + )); + + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; + let config_filename = api_repo.get("config.json").await.ok(); + let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); + let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); + let processor_config_filename = api_repo.get("processor_config.json").await.ok(); + + ( + tokenizer_filename, + config_filename, + tokenizer_config_filename, + preprocessor_config_filename, + processor_config_filename, + ) + } + Type::Cache(cache) => { + let repo = cache.repo(Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.clone().unwrap_or_else(|| "main").to_string(), + )); + ( + repo.get("tokenizer.json"), + repo.get("config.json"), + repo.get("tokenizer_config.json"), + repo.get("preprocessor_config.json"), + repo.get("processor_config.json"), + ) + } + }; + + // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'. 
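+    // An explicitly provided tokenizer_config_path takes precedence over the tokenizer_config.json resolved above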
+ let tokenizer_config: Option = if let Some(filename) = tokenizer_config_path + { + HubTokenizerConfig::from_file(filename) + } else { + tokenizer_config_filename.and_then(HubTokenizerConfig::from_file) + }; + + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()) +} + +#[tokio::main] +async fn main() -> Result<(), TensorRtLlmBackendError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + max_batch_prefill_tokens, + max_batch_total_tokens, + hostname, + port, + tokenizer_name, + tokenizer_config_path, + revision, + model_id, + validation_workers, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + max_client_batch_size, + auth_token, + executor_worker, + usage_stats, + } = args; + + // Launch Tokio runtime + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(TensorRtLlmBackendError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if !executor_worker.exists() { + return Err(TensorRtLlmBackendError::ArgumentValidation(format!( + "`executor_work` specified path doesn't exists: {}", + executor_worker.display() + ))); + } + + // Create the backend + let tokenizer = get_tokenizer( + &tokenizer_name, + tokenizer_config_path.as_deref(), + revision.as_deref(), + ) + .await + .expect("Failed to retrieve tokenizer implementation"); + + info!("Successfully retrieved tokenizer {}", &tokenizer_name); + let backend = TensorRtLlmBackendV2::new( + tokenizer, + model_id, + executor_worker, + max_concurrent_requests, + )?; + + info!("Successfully created backend"); + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + auth_token, + tokenizer_name, + tokenizer_config_path, + revision, + hostname, + port, + cors_allow_origin, + false, + None, + None, + true, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} diff --git a/backends/trtllm/src/utils.rs b/backends/trtllm/src/utils.rs new file mode 100644 index 0000000000000000000000000000000000000000..4dedb0078632b43a5f3446fbb141a91cfc003867 --- /dev/null +++ b/backends/trtllm/src/utils.rs @@ -0,0 +1,22 @@ +/// +/// Extract the first line of the provided string reference. 
+/// If there is no lines in the buffer, it returns a string +/// which content is defined by the content of `fail` +/// # Arguments +/// +/// * `s`: The string buffer to extract the first-line from +/// * `fail`: A string content which is returned if no lines are +/// present in `s` +/// +/// returns: String +/// +/// # Examples +/// +/// ``` +/// let s = "My name is Morgan.\n I'm working at Hugging Face."; +/// first_line(s, "No line in string"); +/// ``` +#[inline] +pub(crate) fn first_line(s: &str, fail: &str) -> String { + s.lines().next().unwrap_or(fail).to_string() +} diff --git a/backends/trtllm/tests/infer_test.cpp b/backends/trtllm/tests/infer_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8520065a759e22cf87908450d4a5f3a9526e7809 --- /dev/null +++ b/backends/trtllm/tests/infer_test.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 7/2/24. +// +#include +#include +#include "../include/backend.h" + +TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") { + const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/"); + const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker"); + + spdlog::info("Loading config from: {}", absolute(engines).string()); + huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor); +} diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4d32474e77f65d7fe0913e8177a1a13e739bdfc6 --- /dev/null +++ b/backends/v2/Cargo.toml @@ -0,0 +1,75 @@ +[package] +name = "text-generation-router-v2" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router-v2" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +slotmap = "1.0.7" +thiserror = "1.0.48" +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } 
+futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] diff --git a/backends/v2/build.rs b/backends/v2/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..f1d85dc70ed09d9d47f5890da215b841834e6824 --- /dev/null +++ b/backends/v2/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs new file mode 100644 index 0000000000000000000000000000000000000000..bc264138d28636abda4e5fe02ec956abf14fd8ba --- /dev/null +++ b/backends/v2/src/backend.rs @@ -0,0 +1,502 @@ +use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient}; +/// Batching and inference logic +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV2 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV2 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + requires_padding: bool, + window_size: Option, + speculate: u32, + ) -> Self { + // Infer shared state + let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string()); + let block_size = match attention.as_str() { + "flashinfer" => 1, + "flashdecoding" => 256, + "paged" => 16, + _ => unreachable!(), + }; + + let queue = Queue::new(requires_padding, block_size, window_size, speculate); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV2 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, 
InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + // Try to get a new batch + if let Some((mut new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); + } + + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = 
Some(entry_waiting_span); + }); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + 
.record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).inspect_err(|_err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. 
This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v2_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/backends/v2/src/client/grpc_client.rs b/backends/v2/src/client/grpc_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..b494352141cc1575854f5926ed70c17d6275dd45 --- /dev/null +++ b/backends/v2/src/client/grpc_client.rs @@ -0,0 +1,257 @@ +/// Single shard Client +use crate::client::pb; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v2::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v2 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] 
+ pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + 
response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v2/src/client/mod.rs b/backends/v2/src/client/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..fa9d440645d61f2d3ba9291c7528b121c1f4fcfc --- /dev/null +++ b/backends/v2/src/client/mod.rs @@ -0,0 +1,68 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v2::{ + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse, + InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. 
+ /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Debug)] +pub struct ShardInfo { + pub requires_padding: bool, + pub dtype: String, + pub device_type: String, + pub window_size: Option, + pub speculate: u32, +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/backends/v2/src/client/sharded_client.rs b/backends/v2/src/client/sharded_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..eccf76d54670436a86d28cf80c275724f6f108a6 --- /dev/null +++ b/backends/v2/src/client/sharded_client.rs @@ -0,0 +1,252 @@ +/// Multi shard Client +use crate::client::{ClientError, Result}; +use crate::client::{Health, ShardInfo}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::InfoResponse; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
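The `ShardedClient` methods that follow all share the same fan-out shape: build one future per shard client, await them together with `join_all`, then fold the per-shard `Result`s into a single `Result` (or keep one representative reply). A minimal sketch of that shape with a dummy client type standing in for the gRPC stubs (assumes the `futures` and `tokio` crates):

```rust
use futures::future::join_all;

type Result<T> = std::result::Result<T, String>;

// Stand-in for a single-shard gRPC client.
struct ShardClient(u32);

impl ShardClient {
    async fn health(&mut self) -> Result<u32> {
        Ok(self.0)
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut clients = vec![ShardClient(0), ShardClient(1), ShardClient(2)];

    // One future per shard, polled concurrently.
    let futures: Vec<_> = clients.iter_mut().map(|c| c.health()).collect();

    // Collecting into Result<Vec<_>> surfaces the first shard error, which is
    // how ShardedClient propagates a single ClientError.
    let replies: Vec<u32> = join_all(futures).await.into_iter().collect::<Result<_>>()?;
    println!("all shards answered: {replies:?}");
    Ok(())
}
```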
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap().map(ShardInfo::from) + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the 
slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +impl From for ShardInfo { + fn from(value: InfoResponse) -> Self { + Self { + requires_padding: value.requires_padding, + dtype: value.dtype, + device_type: value.device_type, + window_size: value.window_size, + speculate: value.speculate, + } + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + truncate: 10, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + }; + self.clone().prefill(batch).await?; + Ok(()) + } +} diff --git a/backends/v2/src/lib.rs b/backends/v2/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..85c36931c5c027474692e7e6dd1480887e3e9966 --- /dev/null +++ b/backends/v2/src/lib.rs @@ -0,0 +1,141 @@ +mod backend; +mod client; +mod queue; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV2; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub max_batch_size: Option, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + 
max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV2, BackendInfo), V2Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V2Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V2Error::Connection)?; + + // server is running on v2 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V2Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V2Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V2Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + }; + + let backend = BackendV2::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V2Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..ab4b7ce1d97ca1f518b224528719928b9dce1812 --- /dev/null +++ b/backends/v2/src/main.rs @@ -0,0 +1,212 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::{server, usage_stats}; +use 
text_generation_router_v2::{connect_backend, V2Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + trust_remote_code, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + disable_grammar_support, + max_client_batch_size, + usage_stats, + } = args; + + if let Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } 
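Each flag in the `Args` struct above follows the same clap pattern: a default value, a `--kebab-case` flag, and an environment-variable fallback derived from the field name. The sketch below shows that pattern plus the first validation rule on a reduced, hypothetical field set (assumes `clap = { version = "4", features = ["derive", "env"] }`):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Resolution order: `--max-input-tokens`, then MAX_INPUT_TOKENS, then the default.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
}

fn main() {
    let args = Args::parse();
    // Same invariant the router enforces before starting the server.
    assert!(
        args.max_input_tokens < args.max_total_tokens,
        "`max_input_tokens` must be < `max_total_tokens`"
    );
    println!("{args:?}");
}
```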
+ if max_input_tokens as u32 > max_batch_prefill_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + + let (backend, _backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + trust_remote_code, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + disable_grammar_support, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V2Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/infer/v2/queue.rs b/backends/v2/src/queue.rs similarity index 95% rename from router/src/infer/v2/queue.rs rename to backends/v2/src/queue.rs index 93cf94699d82e1a5af18bc665209d09601b70974..bf52900f497fce3f28e4e0475552c3e7b18af3e5 100644 --- a/router/src/infer/v2/queue.rs +++ b/backends/v2/src/queue.rs @@ -1,14 +1,14 @@ -use crate::infer::{InferError, InferStreamResponse}; -use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; -use text_generation_client::v2::{ - Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, }; -use text_generation_client::ChunksToString; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -111,7 +111,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - 
metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -124,7 +124,7 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } @@ -205,13 +205,20 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); + next_batch_span.follows_from(Span::current()); let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = @@ -226,7 +233,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -336,7 +343,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } @@ -397,6 +404,7 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; use tracing::info_span; fn default_entry() -> ( @@ -408,7 +416,9 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, + add_special_tokens: true, truncate: 0, decoder_input_details: false, parameters: ValidParameters { diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..69dad072fac03715f024f4b345cefffaa8d652c3 --- /dev/null +++ b/backends/v3/Cargo.toml @@ -0,0 +1,83 @@ +[package] +name = "text-generation-router-v3" +description = "Text Generation Webserver" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "text-generation-router" +path = "src/main.rs" + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +axum = { version = "0.7", features = ["json"] } +axum-tracing-opentelemetry = "0.16" +text-generation-router = { path = "../../router" } +clap = { version = "4.4.5", features = ["derive", "env"] } +grpc-metadata = { path = "../grpc-metadata" } +futures = "0.3.28" +hf-hub = { workspace = true } +jsonschema = { version = "0.17.1", features = ["draft202012"] } +metrics = { workspace = true } +metrics-exporter-prometheus = { workspace = true } +nohash-hasher = "0.2.0" +opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.13.0" +rand = "0.8.5" +reqwest = { version = "0.11.20", features = [] } +serde = "1.0.188" +serde_json = "1.0.107" +slotmap = "1.0.7" +thiserror = "1.0.48" +tokenizers = { workspace = true } +tokio = { 
version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } +tokio-stream = "0.1.14" +tower-http = { version = "0.5.1", features = ["cors"] } +tracing = "0.1.37" +tracing-opentelemetry = "0.21.0" +tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } +futures-util = "0.3.30" +regex = "1.10.3" +once_cell = "1.19.0" +image = "0.25.1" +base64 = { workspace = true } +prost = "^0.12" +tonic = "^0.10" +tower = "^0.4" + +[build-dependencies] +tonic-build = "0.10.1" +prost-build = "0.12.1" + +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + +[features] +default = ["ngrok"] +ngrok = ["text-generation-router/ngrok"] +google = ["text-generation-router/google"] +kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 0000000000000000000000000000000000000000..d9df45b231bf84a81d483f25e43171316906094b --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/build.rs b/backends/v3/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..6d702d14423bd770000a69af9cb9cce7be7fd828 --- /dev/null +++ b/backends/v3/build.rs @@ -0,0 +1,19 @@ +use std::fs; + +fn main() -> Result<(), Box> { + println!("cargo:rerun-if-changed=../../proto/"); + + fs::create_dir_all("src/client/pb").unwrap_or(()); + let mut config = prost_build::Config::new(); + config.protoc_arg("--experimental_allow_proto3_optional"); + + tonic_build::configure() + .build_client(true) + .build_server(false) + .out_dir("src/client/pb") + .include_file("mod.rs") + .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"]) + .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); + + Ok(()) +} diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs new file mode 100644 index 0000000000000000000000000000000000000000..a5c0f5125b25ca8b1d3e284e9efccb82b310b706 --- /dev/null +++ b/backends/v3/src/backend.rs @@ -0,0 +1,548 @@ +/// Batching and inference logic +use 
crate::client::{ + Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient, +}; +use crate::queue::{Entry, Queue}; +use async_trait::async_trait; +use nohash_hasher::IntMap; +use std::sync::Arc; +use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; +use text_generation_router::{FinishReason, PrefillToken, Token}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::{mpsc, Notify}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{info_span, instrument, Instrument, Span}; + +pub struct BackendV3 { + /// Request queue + queue: Queue, + /// Notify batcher on queue appends + batching_task_notifier: Arc, + /// Client clone, used for health checks to skip the queue + client: ShardedClient, +} + +impl BackendV3 { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + shard_info: InfoResponse, + ) -> Self { + if shard_info.support_chunking { + tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored."); + } + + let block_size = shard_info.block_size; + + let queue = Queue::new( + shard_info.requires_padding, + block_size, + shard_info.use_prefix_caching, + shard_info.window_size, + shard_info.speculate, + max_batch_total_tokens, + shard_info.support_chunking, + ); + let batching_task_notifier = Arc::new(Notify::new()); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client.clone(), + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info.support_chunking, + queue.clone(), + batching_task_notifier.clone(), + )); + + Self { + queue, + batching_task_notifier, + client, + } + } +} + +#[async_trait] +impl Backend for BackendV3 { + #[instrument(skip_all)] + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the queue + self.queue.append(Entry { + request, + response_tx, + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + }); + + // Notify the background task that we have a new entry in the queue that needs + // to be batched + self.batching_task_notifier.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + async fn health(&self, current_health: bool) -> bool { + if current_health { + // Generation is healthy, we only check that the shards can allocate on device + self.client.device_health().await + } else { + self.client.model_health().await + } + .is_ok() + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[allow(clippy::too_many_arguments)] +pub(crate) async fn batching_task( + mut client: ShardedClient, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: u32, + max_waiting_tokens: usize, + max_batch_size: Option, + support_chunking: bool, + queue: Queue, + notifier: Arc, +) { + // Infinite loop + loop { + // Wait for a notification from 
the Infer struct + notifier.notified().await; + + // Get the next batch from the queue + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the queue + while let Some((mut entries, batch, span)) = queue + .next_batch( + None, + max_batch_size, + max_batch_prefill_tokens, + max_batch_total_tokens, + ) + .await + { + let mut cached_batch = prefill(&mut client, batch, None, &mut entries) + .instrument(span) + .await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let batch_max_tokens = batch.max_tokens; + let current_tokens = batch.current_tokens; + let mut batches = vec![batch]; + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); + + let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); + + let (min_size, max_size, prefill_token_budget) = if support_chunking { + // Since the next batch will be concatenated with the current batch, + // the current batch tokens must be subtracted to the prefill budget + let prefill_token_budget = + max_batch_prefill_tokens.saturating_sub(current_tokens); + // We can ignore min_size and max_size + // Models than rely on max_size cannot support chunking + // Regarding min_size, chunking allow us to consistently run at the compute + // bound, making min_size useless. + (None, None, prefill_token_budget) + } else { + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. 
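To make the chunking budget arithmetic above concrete, here is a worked example with made-up numbers; only the two `saturating_sub` computations mirror the code, the values themselves are illustrative:

```rust
fn main() {
    // Limits taken from the CLI flags.
    let max_batch_prefill_tokens: u32 = 4096;
    let max_batch_total_tokens: u32 = 32_000;

    // Hypothetical state of the running batch.
    let current_tokens: u32 = 1_500; // tokens currently being prefilled in the batch
    let batch_max_tokens: u32 = 9_000; // worst-case tokens reserved by the batch

    // With prefill chunking, new requests share the prefill budget with the
    // tokens already in flight, so those are subtracted from it.
    let prefill_token_budget = max_batch_prefill_tokens.saturating_sub(current_tokens);
    // The remaining total-token budget is whatever the running batch has not reserved.
    let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

    assert_eq!(prefill_token_budget, 2_596);
    assert_eq!(token_budget, 23_000);
    println!("prefill budget: {prefill_token_budget}, token budget: {token_budget}");
}
```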
+ Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + (min_size, max_size, max_batch_prefill_tokens) + }; + + // Try to get a new batch + if let Some((new_entries, new_batch, span)) = queue + .next_batch(min_size, max_size, prefill_token_budget, token_budget) + .await + { + // Tracking metrics + if min_size.is_some() { + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); + } else { + let counter = if support_chunking { + metrics::counter!("tgi_batch_concat", "reason" => "chunking") + } else { + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + }; + counter.increment(1); + } + let cached_batch = if support_chunking { + // Concat current batch to the new one + batches.pop() + } else { + // Request are waiting only if we don't support chunking + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to add the info that this entry is waiting + // because a new batch is being computed + let entry_waiting_span = info_span!(parent: &entry.span, "waiting"); + // Add relationships + span.follows_from(&entry_waiting_span); + entry_waiting_span.follows_from(&span); + // Update entry + entry.temp_span = Some(entry_waiting_span); + }); + None + }; + entries.extend(new_entries); + + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + prefill(&mut client, new_batch, cached_batch, &mut entries) + .instrument(span) + .await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + batches.push(new_cached_batch); + } else if support_chunking { + // New cached batch is empty, no work left + break; + } + } + + // Create span for this batch to add context to inference calls + let next_batch_size = entries.len(); + let next_batch_span = + info_span!(parent: None, "batch", batch_size = next_batch_size); + entries.iter_mut().for_each(|(_, entry)| { + // Create a new span to link the batch back to this entry + let entry_batch_span = info_span!(parent: &entry.span, "infer"); + // Add relationships + next_batch_span.follows_from(&entry_batch_span); + entry_batch_span.follows_from(&next_batch_span); + // Update entry + entry.temp_span = Some(entry_batch_span); + }); + + cached_batch = decode(&mut client, batches, &mut entries) + .instrument(next_batch_span) + .await; + waiting_tokens += 1; + } + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); + } + } +} + +#[instrument(skip_all)] +async fn prefill( + client: &mut ShardedClient, + batch: Batch, + cached_batch: Option, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_id = batch.id; + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); + + match client.prefill(batch, cached_batch).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + 
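Both the `waiting` spans above and the `infer` spans created for the next decode rely on the same tracing idiom: a child span of the entry is cross-linked with the batch-level span through `follows_from`, so a trace can be walked from either side. Below is a minimal, self-contained sketch of that linking; the span names and fields are placeholders rather than the backend's exact ones.

use tracing::{info_span, Span};

/// Illustrative only: mirrors the double `follows_from` linking used for batch and entry spans.
fn link_entry_to_batch(entry_span: &Span, batch_span: &Span) -> Span {
    // Child of the entry's own span so per-request context is inherited.
    let entry_batch_span = info_span!(parent: entry_span, "infer");
    // Cross-link in both directions; neither span becomes the other's ancestor.
    batch_span.follows_from(&entry_batch_span);
    entry_batch_span.follows_from(batch_span);
    entry_batch_span
}

fn main() {
    // Without a subscriber installed these spans are disabled; only the linking calls matter here.
    let batch_span = info_span!(parent: None, "batch", batch_size = 2usize);
    let entry_span = info_span!("request", request_id = 42u64);
    let _linked = link_entry_to_batch(&entry_span, &batch_span);
}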
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + let _ = client.clear_cache(Some(batch_id)).await; + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); + None + } + } +} + +#[instrument(skip_all)] +async fn decode( + client: &mut ShardedClient, + batches: Vec, + entries: &mut IntMap, +) -> Option { + let start_time = Instant::now(); + let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); + + match client.decode(batches).await { + Ok((generations, next_batch, timings)) => { + let start_filtering_time = Instant::now(); + // Send generated tokens and filter stopped entries + filter_send_generations(generations, entries); + + // Filter next batch and remove requests that were stopped + let next_batch = filter_batch(client, next_batch, entries).await; + + if let Some(concat_duration) = timings.concat { + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); + } + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + for id in batch_ids { + let _ = client.clear_cache(Some(id)).await; + } + send_errors(err, entries); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); + None + } + } +} + +/// Filter a `batch` and remove all requests not present in `entries` +#[instrument(skip_all)] +async fn filter_batch( + client: &mut ShardedClient, + next_batch: Option, + entries: &IntMap, +) -> Option { + let mut batch = next_batch?; + + // No need to filter + if batch.size as usize == entries.len() { + return Some(batch); + } + + let id = batch.id; + + // Retain only requests that are still in entries + batch.request_ids.retain(|id| entries.contains_key(id)); + + if batch.request_ids.is_empty() { + // All requests have been filtered out + // Next batch is now empty + // Clear it from the Python shards cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.clear_cache(Some(id)).await.unwrap(); + None + } else { + // Filter Python shard cache + // We unwrap here as we need to panic since we cannot recover if this method fails + client.filter_batch(id, batch.request_ids).await.unwrap() + } +} + +/// Send one or multiple 
`InferStreamResponse` to Infer for all `entries` +/// and filter entries +#[instrument(skip_all)] +fn filter_send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + let id = generation.request_id; + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&id) + .expect("ID not found in entries. This is a bug."); + + // Create and enter a span to link this function back to the entry + let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered(); + // Send generation responses back to the infer task + // If the receive an error from the Flume channel, it means that the client dropped the + // request and we need to stop generating hence why we unwrap_or(true) + let stopped = send_responses(generation, entry).inspect_err(|_err| { + tracing::error!("Entry response channel error."); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + }).unwrap_or(true); + if stopped { + entries.remove(&id).expect("ID not found in entries. This is a bug."); + } + }); +} + +/// Send responses through the `entry` response channel +fn send_responses( + generation: Generation, + entry: &Entry, +) -> Result>>> { + // Return directly if the channel is disconnected + if entry.response_tx.is_closed() { + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); + return Ok(true); + } + + let mut stopped = false; + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + let prefill_tokens = prefill_tokens + .ids + .into_iter() + .zip(prefill_tokens.logprobs) + .zip(prefill_tokens.texts) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) + .collect(); + + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?; + } + + // Create last Token + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); + let mut iterator = tokens_ + .ids + .into_iter() + .zip(tokens_.logprobs) + .zip(tokens_.texts) + .zip(tokens_.is_special) + .enumerate() + .peekable(); + while let Some((i, (((id, logprob), text), special))) = iterator.next() { + let token = Token { + id, + text, + logprob, + special, + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) { + top_tokens_ + .ids + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { + id, + text: text.to_string(), + logprob, + special, + }) + .collect() + } else { + vec![] + }; + match (&generation.generated_text, iterator.peek()) { + (Some(generated_text), None) => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: GeneratedText::from(generated_text.clone()), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } + } + + Ok(stopped) +} + +/// Send errors to Infer for all `entries` +#[instrument(skip_all)] +fn send_errors(error: ClientError, entries: &mut IntMap) 
{ + entries.drain().for_each(|(_, entry)| { + // Create and enter a span to link this function back to the entry + let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); + let err = InferError::GenerationError(error.to_string()); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); + tracing::error!("{err}"); + + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(err)) + .unwrap_or(()); + }); +} + +impl From for GeneratedText { + fn from(value: crate::client::GeneratedText) -> Self { + let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap(); + let finish_reason = match v3_finish_reason { + crate::client::FinishReason::Length => FinishReason::Length, + crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + crate::client::FinishReason::StopSequence => FinishReason::StopSequence, + }; + + Self { + text: value.text, + generated_tokens: value.generated_tokens, + finish_reason, + seed: value.seed, + } + } +} diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs new file mode 100644 index 0000000000000000000000000000000000000000..4fea172b65a062d0d6b130ffcfdc5815a25d289f --- /dev/null +++ b/backends/v3/src/block_allocator.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +use crate::radix::RadixAllocator; + +#[derive(Debug, Clone)] +pub struct BlockAllocation { + pub allocation_id: u64, + pub blocks: Vec, + pub slots: Vec, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, +} + +impl Drop for BlockAllocation { + fn drop(&mut self) { + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } + } +} + +#[derive(Debug, Clone)] +pub struct BlockAllocator { + /// Channel to communicate with the background task + block_allocator: mpsc::UnboundedSender, +} + +impl BlockAllocator { + pub(crate) fn new( + max_batch_total_tokens: u32, + block_size: u32, + prefix_caching: bool, + window_size: Option, + ) -> Self { + // Create channel + let (sender, receiver) = mpsc::unbounded_channel(); + + // Launch background queue task + tokio::spawn(block_allocator_task( + max_batch_total_tokens / block_size, + block_size, + prefix_caching, + window_size, + receiver, + )); + + Self { + block_allocator: sender, + } + } + + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let (response_sender, response_receiver) = oneshot::channel(); + self.block_allocator + .send(BlockAllocatorCommand::Allocate { + tokens, + prefill_tokens, + response_sender, + }) + .unwrap(); + + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) + } + + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { + self.block_allocator + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) + .unwrap(); + } +} + +async fn block_allocator_task( + blocks: u32, + block_size: u32, + prefix_caching: bool, + window_size: Option, + mut receiver: mpsc::UnboundedReceiver, +) { + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; + while let 
Some(cmd) = receiver.recv().await { + match cmd { + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), + BlockAllocatorCommand::Allocate { + tokens, + prefill_tokens, + response_sender, + } => { + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); + } + } + } +} + +#[derive(Debug)] +enum BlockAllocatorCommand { + Free { + blocks: Vec, + allocation_id: u64, + }, + Allocate { + tokens: u32, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, + }, +} + +pub trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = core::cmp::min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..fe810f24742fdc1da7aa004006a29757b4932087 --- /dev/null +++ b/backends/v3/src/client/grpc_client.rs @@ -0,0 +1,299 @@ +/// Single shard Client +use crate::client::{pb, Chunk}; +use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use grpc_metadata::InjectTelemetryContext; +use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; +use pb::generate::v3::*; +use std::cmp::min; +use std::time::Duration; +use tonic::transport::{Channel, Uri}; +use tracing::instrument; + +/// Text Generation Inference gRPC client +#[derive(Debug, Clone)] +pub struct Client { + stub: TextGenerationServiceClient, +} + +impl Client { + /// Returns a client connected to the given url + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let channel = Channel::builder(uri).connect().await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a client connected to the given unix 
socket + pub async fn connect_uds(path: String) -> Result { + let channel = Channel::from_shared("http://[::]:50051".to_string()) + .unwrap() + .connect_with_connector(tower::service_fn(move |_: Uri| { + tokio::net::UnixStream::connect(path.clone()) + })) + .await?; + + Ok(Self { + stub: TextGenerationServiceClient::new(channel), + }) + } + + /// Returns a list of uris or unix sockets of all shards + #[instrument(skip(self))] + pub async fn service_discovery(&mut self) -> Result> { + let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; + let urls = response + .into_inner() + .urls + .into_iter() + // Remove unix socket prefix + .map(|url| match url.strip_prefix("unix://") { + None => url, + Some(stripped_url) => stripped_url.to_string(), + }) + .collect(); + Ok(urls) + } + + /// Get model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let request = tonic::Request::new(InfoRequest {}).inject_context(); + let response = self.stub.info(request).await?.into_inner(); + Ok(response) + } + + /// Get model health + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let request = tonic::Request::new(HealthRequest {}).inject_context(); + let response = self.stub.health(request).await?.into_inner(); + Ok(response) + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); + self.stub.clear_cache(request).await?; + Ok(()) + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let request = tonic::Request::new(FilterBatchRequest { + batch_id, + request_ids, + }) + .inject_context(); + let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); + Ok(filtered_batch.batch) + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip_all)] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let mut n_tokens = 0; + let mut requests = Vec::new(); + // Create requests + while n_tokens < max_prefill_tokens { + let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + + let mut inputs = String::new(); + inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + if n_tokens == 0 { + // 1 request is enough to test vision heads. + // Sending images on other queries messes up easily with truncation. 
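Before the markdown image is appended just below, the surrounding loop has already fixed how many synthetic requests the warmup sends and how much each one is truncated. The sketch reproduces only that sizing arithmetic; `warmup_request_sizes` is a hypothetical helper, not part of the client API.

use std::cmp::min;

/// Sketch (assumed semantics) of the warmup sizing loop: requests of up to `max_input_length`
/// tokens are added until the prefill budget is covered or `max_batch_size` caps the batch.
fn warmup_request_sizes(
    max_input_length: u32,
    max_prefill_tokens: u32,
    max_batch_size: Option<usize>,
) -> Vec<u32> {
    let mut n_tokens = 0;
    let mut truncates = Vec::new();
    while n_tokens < max_prefill_tokens {
        truncates.push(min(max_input_length, max_prefill_tokens - n_tokens));
        n_tokens += max_input_length;
        if Some(truncates.len()) == max_batch_size {
            break;
        }
    }
    truncates
}

fn main() {
    // 4096 prefill tokens with 1024-token inputs -> four full-length requests.
    assert_eq!(warmup_request_sizes(1024, 4096, None), vec![1024; 4]);
    // A max batch size of 2 stops the loop early.
    assert_eq!(warmup_request_sizes(1024, 4096, Some(2)), vec![1024, 1024]);
    // A budget that is not a multiple of the input length truncates the last request.
    assert_eq!(warmup_request_sizes(1000, 2500, None), vec![1000, 1000, 500]);
}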
+ inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); + } + + requests.push(Request { + id: 0, + inputs, + add_special_tokens: true, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size + truncate, + // Blocks and slots will be set on the server side if we use paged attention + blocks: vec![], + slots: vec![], + cache_len: 0, + chunk_len: None, + // Set sampling parameters to also take these ops into account in the max memory + parameters: Some(NextTokenChooserParameters { + temperature: 0.9, + top_k: 10, + top_p: 0.9, + typical_p: 0.9, + do_sample: false, + seed: 0, + repetition_penalty: 1.2, + frequency_penalty: 0.1, + watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: max_total_tokens - truncate, + stop_sequences: vec![], + ignore_eos_token: true, + }), + prefill_logprobs: true, + top_n_tokens: 20, + adapter_id: None, + }); + n_tokens += max_input_length; + + // Check max_batch_size + if Some(requests.len()) == max_batch_size { + break; + } + } + + let batch = Batch { + id: 0, + size: requests.len() as u32, + requests, + max_tokens: max_input_length, + max_blocks: 0, + }; + + let request = tonic::Request::new(WarmupRequest { + batch: Some(batch), + max_input_length, + max_prefill_tokens, + max_total_tokens, + }) + .inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + cached_batch: Option, + ) -> Result<(Vec, Option, PrefillTimings)> { + let request = tonic::Request::new(PrefillRequest { + batch: Some(batch), + cached_batch, + }) + .inject_context(); + let response = self.stub.prefill(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + PrefillTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); + let response = self.stub.decode(request).await?.into_inner(); + Ok(( + response.generations, + response.batch, + DecodeTimings::new( + response.concat_ns, + response.forward_ns, + response.decode_ns, + response.total_ns, + ), + )) + } +} + +pub struct PrefillTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl PrefillTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} + +pub struct DecodeTimings { + pub concat: Option, + pub forward: Duration, + pub decode: Duration, + pub total: Duration, +} + +impl 
DecodeTimings { + fn new(concat_ns: Option, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self { + Self { + concat: concat_ns.map(Duration::from_nanos), + forward: Duration::from_nanos(forward_ns), + decode: Duration::from_nanos(decode_ns), + total: Duration::from_nanos(total_ns), + } + } +} diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..d4ac50c9c46c24a82a0655eae5a2bd8c4e3728dc --- /dev/null +++ b/backends/v3/src/client/mod.rs @@ -0,0 +1,67 @@ +//! Text Generation gRPC client library + +use async_trait::async_trait; +use thiserror::Error; +use tonic::transport; +use tonic::Status; + +#[allow(clippy::derive_partial_eq_without_eq)] +mod pb; + +mod grpc_client; +mod sharded_client; + +pub use grpc_client::Client; +pub use pb::generate::v3::{ + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, +}; +pub use sharded_client::ShardedClient; + +#[async_trait] +pub trait Health { + /// Check if a generate server is healthy by asking it to allocate a tensor on device + async fn device_health(&self) -> Result<()>; + + /// Check if a generate server is healthy by doing a forward pass. + /// EXPENSIVE + async fn model_health(&self) -> Result<()>; +} + +#[derive(Error, Debug, Clone)] +pub enum ClientError { + #[error("Could not connect to Text Generation server: {0}")] + Connection(String), + #[error("Server error: {0}")] + Generation(String), + #[error("Sharded results are empty")] + EmptyResults, +} + +impl From for ClientError { + fn from(err: Status) -> Self { + let err = Self::Generation(err.message().to_string()); + tracing::error!("{err}"); + err + } +} + +impl From for ClientError { + fn from(err: transport::Error) -> Self { + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err + } +} + +// Small convenience re-wrapping of `Chunk`. 
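The `From` implementations above are what let every call site in this module funnel gRPC `Status` values and transport failures into `ClientError` with the `?` operator. The sketch below shows the same pattern on a deliberately simplified, re-declared error type: `ShardError`, `call_shard`, and `decode_once` are illustrative names, not the crate's API.

use thiserror::Error;
use tonic::Status;

#[derive(Error, Debug, Clone)]
#[allow(dead_code)]
enum ShardError {
    #[error("Could not connect to shard: {0}")]
    Connection(String),
    #[error("Shard returned an error: {0}")]
    Generation(String),
}

impl From<Status> for ShardError {
    fn from(status: Status) -> Self {
        let err = ShardError::Generation(status.message().to_string());
        // Logging inside the conversion reports the failure exactly once, however far
        // the `?` operator ends up propagating it.
        tracing::error!("{err}");
        err
    }
}

fn call_shard() -> Result<u32, Status> {
    // Stand-in for a gRPC call that failed on the server side.
    Err(Status::resource_exhausted("batch is full"))
}

fn decode_once() -> Result<u32, ShardError> {
    // `?` converts the tonic `Status` into `ShardError` through the `From` impl above.
    let generated_tokens = call_shard()?;
    Ok(generated_tokens)
}

fn main() {
    match decode_once() {
        Ok(n) => println!("generated {n} tokens"),
        Err(err) => eprintln!("{err}"),
    }
}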
+impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/router/client/src/v3/pb/generate.v3.rs b/backends/v3/src/client/pb/generate.v3.rs similarity index 95% rename from router/client/src/v3/pb/generate.v3.rs rename to backends/v3/src/client/pb/generate.v3.rs index 72315ea31ecf55b2aff4cdbc865859833d8e6633..27b45d191ba9131b6ee452c488b91b3aa7abef11 100644 --- a/router/client/src/v3/pb/generate.v3.rs +++ b/backends/v3/src/client/pb/generate.v3.rs @@ -22,6 +22,14 @@ pub struct InfoResponse { pub window_size: ::core::option::Option, #[prost(uint32, tag = "5")] pub speculate: u32, + #[prost(bool, tag = "6")] + pub support_chunking: bool, + #[prost(bool, tag = "7")] + pub use_prefix_caching: bool, + #[prost(string, tag = "8")] + pub attention_impl: ::prost::alloc::string::String, + #[prost(uint32, tag = "9")] + pub block_size: u32, } /// / Empty request #[allow(clippy::derive_partial_eq_without_eq)] @@ -167,6 +175,17 @@ pub struct Request { /// / LORA adapter index #[prost(string, optional, tag = "11")] pub adapter_id: ::core::option::Option<::prost::alloc::string::String>, + /// / Tokens that can be retrieved from the KV cache. 
+ /// / This value is set for the first prefill and never reset + #[prost(uint32, tag = "12")] + pub cache_len: u32, + /// / Context truncation + #[prost(bool, tag = "13")] + pub add_special_tokens: bool, + /// / Chunk of tokens that must be computed for the first prefill + /// / This value is set for the first prefill and never reset + #[prost(uint32, optional, tag = "14")] + pub chunk_len: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -202,6 +221,9 @@ pub struct CachedBatch { /// / Maximum number of tokens this batch will grow to #[prost(uint32, tag = "4")] pub max_tokens: u32, + /// / Number of tokens in the next forward + #[prost(uint32, tag = "5")] + pub current_tokens: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -276,6 +298,9 @@ pub struct PrefillRequest { /// / Batch #[prost(message, optional, tag = "1")] pub batch: ::core::option::Option, + /// / Optional cached batch + #[prost(message, optional, tag = "2")] + pub cached_batch: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -295,6 +320,9 @@ pub struct PrefillResponse { /// / Total elapsed time in nanoseconds #[prost(uint64, tag = "5")] pub total_ns: u64, + /// / Concatenate elapsed time in nanoseconds + #[prost(uint64, optional, tag = "6")] + pub concat_ns: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/router/client/src/v3/pb/mod.rs b/backends/v3/src/client/pb/mod.rs similarity index 100% rename from router/client/src/v3/pb/mod.rs rename to backends/v3/src/client/pb/mod.rs diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..e181cd28d2fdba0809569dc4321f209261ba65c6 --- /dev/null +++ b/backends/v3/src/client/sharded_client.rs @@ -0,0 +1,252 @@ +use crate::client::Health; +/// Multi shard Client +use crate::client::{ClientError, Result}; + +use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; +use crate::client::{ + Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, +}; +use crate::client::{Chunk, InfoResponse, Input}; +use async_trait::async_trait; +use futures::future::join_all; +use tonic::transport::Uri; +use tracing::instrument; + +#[derive(Debug, Clone)] +/// Text Generation Inference gRPC multi client +pub struct ShardedClient { + clients: Vec, +} + +impl ShardedClient { + fn new(clients: Vec) -> Self { + Self { clients } + } + + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. 
+ async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client + let uris = master_client.service_discovery().await?; + let futures = uris.into_iter().map(Client::connect_uds); + let clients: Result> = join_all(futures).await.into_iter().collect(); + Ok(Self::new(clients?)) + } + + /// Returns a client connected to the given uri + #[allow(dead_code)] + pub async fn connect(uri: Uri) -> Result { + let master_client = Client::connect(uri).await?; + Self::from_master_client(master_client).await + } + + /// Returns a client connected to the given unix socket + pub async fn connect_uds(path: String) -> Result { + let master_client = Client::connect_uds(path).await?; + Self::from_master_client(master_client).await + } + + /// Get the model info + #[instrument(skip(self))] + pub async fn info(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.info()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// GRPC health check + #[instrument(skip(self))] + pub async fn health(&mut self) -> Result { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.health()) + .collect(); + join_all(futures).await.pop().unwrap() + } + + /// Clear the past generations cache + #[instrument(skip(self))] + pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| client.clear_cache(batch_id)) + .collect(); + join_all(futures).await.into_iter().collect() + } + + /// Filter a cached batch + #[instrument(skip(self))] + pub async fn filter_batch( + &mut self, + batch_id: u64, + request_ids: Vec, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone()))) + .collect(); + // all shards return the same message + join_all(futures).await.pop().unwrap() + } + + /// Warmup on a max size batch + /// + /// Returns the maximum amount of tokens supported by the hardware + #[instrument(skip(self))] + pub async fn warmup( + &mut self, + max_input_length: u32, + max_prefill_tokens: u32, + max_total_tokens: u32, + max_batch_size: Option, + ) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| { + Box::pin(client.warmup( + max_input_length, + max_prefill_tokens, + max_total_tokens, + max_batch_size, + )) + }) + .collect(); + // Take the minimum value + let results = join_all(futures) + .await + .into_iter() + .collect::>>>()?; + Ok(results.into_iter().flatten().min()) + } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))] + pub async fn prefill( + &mut self, + batch: Batch, + cached_batch: Option, + ) -> Result<(Vec, Option, PrefillTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, PrefillTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // 
Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))] + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option, DecodeTimings)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + #[allow(clippy::type_complexity)] + let results: Result, Option, DecodeTimings)>> = + join_all(futures).await.into_iter().collect(); + let mut results = results?; + + let (mut generations, next_batch, mut timings) = + results.pop().ok_or(ClientError::EmptyResults)?; + + // Merge generations from different model shards + for (mut shard_generations, _, shard_timings) in results.into_iter() { + generations.append(&mut shard_generations); + // Return the timings of the slowest shard + if shard_timings.total > timings.total { + timings = shard_timings; + } + } + Ok((generations, next_batch, timings)) + } +} + +#[async_trait] +impl Health for ShardedClient { + async fn device_health(&self) -> Result<()> { + self.clone().health().await?; + Ok(()) + } + + async fn model_health(&self) -> Result<()> { + // Dummy batch of 1 token and 1 generated token + let liveness_request = Request { + id: u64::MAX, + inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), + truncate: 10, + add_special_tokens: true, + prefill_logprobs: false, + parameters: Some(NextTokenChooserParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + typical_p: 1.0, + do_sample: false, + seed: 0, + repetition_penalty: 1.0, + frequency_penalty: 0.0, + watermark: false, + grammar: String::new(), + grammar_type: GrammarType::None as i32, + }), + stopping_parameters: Some(StoppingCriteriaParameters { + max_new_tokens: 1, + stop_sequences: vec![], + ignore_eos_token: false, + }), + top_n_tokens: 0, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: (0..16).collect(), + cache_len: 0, + adapter_id: None, + chunk_len: None, + }; + let batch = Batch { + id: u64::MAX, + requests: vec![liveness_request], + size: 1, + max_tokens: 2, + max_blocks: 1, + }; + self.clone().prefill(batch, None).await?; + Ok(()) + } +} diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..7daf9eaeca7ae7a52bffce5d872ed93c67b9d4c8 --- /dev/null +++ b/backends/v3/src/lib.rs @@ -0,0 +1,154 @@ +mod backend; +pub mod block_allocator; +mod client; +mod queue; +pub mod radix; + +use crate::client::{ClientError, ShardedClient}; +pub(crate) use backend::BackendV3; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Serialize, ToSchema)] +pub struct BackendInfo { + /// Mandatory + #[schema(example = "cuda")] + pub model_device_type: String, + #[schema(example = "torch.float16")] + pub model_dtype: String, + + /// Backend parameters + #[schema(example = "1")] + pub speculate: usize, + #[schema(example = "1.2")] + pub waiting_served_ratio: f32, + #[schema(example = "32000")] + pub max_batch_total_tokens: u32, + #[schema(example = "20")] + pub max_waiting_tokens: usize, + #[schema(nullable = true, example = "null")] + pub 
max_batch_size: Option, + #[schema(example = "false")] + pub support_chunking: bool, + #[schema(example = "false")] + pub prefix_caching: bool, + #[schema(example = "flashinfer")] + pub attention_impl: String, + #[schema(example = "1")] + pub block_size: u32, +} + +#[allow(clippy::too_many_arguments)] +pub async fn connect_backend( + max_input_tokens: usize, + max_total_tokens: usize, + master_shard_uds_path: String, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, +) -> Result<(BackendV3, BackendInfo), V3Error> { + // Helper function + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(max_total_tokens)); + } + + Ok(max_supported_batch_total_tokens) + } + } + }; + + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(V3Error::Connection)?; + + // server is running on v3 + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(V3Error::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(V3Error::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens( + sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?, + )?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens); + + let backend_info = BackendInfo { + waiting_served_ratio, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + model_device_type: shard_info.device_type.clone(), + model_dtype: shard_info.dtype.clone(), + speculate: shard_info.speculate as usize, + support_chunking: shard_info.support_chunking, + prefix_caching: shard_info.use_prefix_caching, + attention_impl: shard_info.attention_impl.clone(), + block_size: shard_info.block_size, + }; + + let backend = BackendV3::new( + sharded_client, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + shard_info, + ); + + tracing::info!("Using backend V3"); + + Ok((backend, backend_info)) +} + +#[derive(Debug, Error)] +pub enum V3Error { + #[error("Unable to clear the Python model shards cache: {0}")] + Cache(ClientError), + #[error("Unable to connect to the Python model shards: {0}")] + Connection(ClientError), + #[error("Unable to get the Python model shards info: {0}")] + 
Info(ClientError), + #[error("Unable to warmup the Python model shards: {0}")] + Warmup(ClientError), + #[error("Not enough memory to handle `max_total_tokens={0}`")] + NotEnoughMemory(usize), +} diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..bc4bdb934ebd1115b17a333b0a11ddb0ccd59e23 --- /dev/null +++ b/backends/v3/src/main.rs @@ -0,0 +1,212 @@ +use clap::{Parser, Subcommand}; +use text_generation_router::{server, usage_stats}; +use text_generation_router_v3::{connect_backend, V3Error}; +use thiserror::Error; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[command(subcommand)] + command: Option, + + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, + #[clap(default_value = "4", long, env)] + max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + #[clap(default_value = "1024", long, env)] + max_input_tokens: usize, + #[clap(default_value = "2048", long, env)] + max_total_tokens: usize, + #[clap(default_value = "1.2", long, env)] + waiting_served_ratio: f32, + #[clap(default_value = "4096", long, env)] + max_batch_prefill_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, + #[clap(default_value = "20", long, env)] + max_waiting_tokens: usize, + #[clap(long, env)] + max_batch_size: Option, + #[clap(default_value = "0.0.0.0", long, env)] + hostname: String, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server-0", long, env)] + master_shard_uds_path: String, + #[clap(default_value = "bigscience/bloom", long, env)] + tokenizer_name: String, + #[clap(long, env)] + tokenizer_config_path: Option, + #[clap(long, env)] + revision: Option, + #[clap(long, env, value_enum)] + trust_remote_code: bool, + #[clap(default_value = "2", long, env)] + validation_workers: usize, + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] + json_output: bool, + #[clap(long, env)] + otlp_endpoint: Option, + #[clap(default_value = "text-generation-inference.router", long, env)] + otlp_service_name: String, + #[clap(long, env)] + cors_allow_origin: Option>, + #[clap(long, env)] + ngrok: bool, + #[clap(long, env)] + ngrok_authtoken: Option, + #[clap(long, env)] + ngrok_edge: Option, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, + #[clap(default_value = "on", long, env)] + usage_stats: usage_stats::UsageStatsLevel, +} + +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + +#[tokio::main] +async fn main() -> Result<(), RouterError> { + // Get args + let args = Args::parse(); + // Pattern match configuration + let Args { + command, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + hostname, + port, + master_shard_uds_path, + tokenizer_name, + tokenizer_config_path, + revision, + trust_remote_code, + validation_workers, + api_key, + json_output, + otlp_endpoint, + otlp_service_name, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + disable_grammar_support, + max_client_batch_size, + usage_stats, + } = args; + + if let 
Some(Commands::PrintSchema) = command { + use utoipa::OpenApi; + let api_doc = text_generation_router::server::ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + }; + text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + + // Validate args + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + + if validation_workers == 0 { + return Err(RouterError::ArgumentValidation( + "`validation_workers` must be > 0".to_string(), + )); + } + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + + let (backend, backend_info) = connect_backend( + max_input_tokens, + max_total_tokens, + master_shard_uds_path, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + ) + .await?; + + // Validate remaining args now that the backend is known + let support_chunking = backend_info.support_chunking; + let max_batch_total_tokens = backend_info.max_batch_total_tokens; + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + + // Run server + server::run( + backend, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + validation_workers, + api_key, + tokenizer_name, + tokenizer_config_path, + revision, + trust_remote_code, + hostname, + port, + cors_allow_origin, + ngrok, + ngrok_authtoken, + ngrok_edge, + disable_grammar_support, + max_client_batch_size, + usage_stats, + ) + .await?; + Ok(()) +} + +#[derive(Debug, Error)] +enum RouterError { + #[error("Argument validation error: {0}")] + ArgumentValidation(String), + #[error("Backend failed: {0}")] + Backend(#[from] V3Error), + #[error("WebServer error: {0}")] + WebServer(#[from] server::WebServerError), + #[error("Tokio runtime failed to start: {0}")] + Tokio(#[from] std::io::Error), +} diff --git a/router/src/infer/v3/queue.rs b/backends/v3/src/queue.rs similarity index 74% rename from router/src/infer/v3/queue.rs rename to backends/v3/src/queue.rs index ba65b9b6a8c1005e9acca8448bf8fa328bbb6d44..6662b8de1f9240ab4e22d2b27dc65e8961bb74ba 100644 --- a/router/src/infer/v3/queue.rs +++ b/backends/v3/src/queue.rs @@ -1,17 +1,17 @@ -use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator}; -use crate::infer::InferError; -use crate::infer::InferStreamResponse; -use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, +use crate::block_allocator::{BlockAllocation, BlockAllocator}; +use crate::client; +use crate::client::{ + Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; -use text_generation_client::v3::{ - Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, +use text_generation_router::infer::InferError; +use text_generation_router::infer::InferStreamResponse; +use text_generation_router::validation::{ + Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, + ValidStoppingParameters, }; -use text_generation_client::ChunksToString; -use text_generation_client::Input; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -46,9 +46,11 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -57,9 +59,11 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, + support_chunking, queue_receiver, )); @@ -85,6 +89,10 @@ impl Queue { prefill_token_budget: u32, token_budget: u32, ) -> Option { + if prefill_token_budget == 0 || token_budget == 0 { + return None; + }; + // Create response channel let (response_sender, response_receiver) = oneshot::channel(); // Send next batch command to the background task managing the state @@ -106,27 +114,32 @@ impl Queue { } // Background task responsible of the queue state +#[allow(clippy::too_many_arguments)] async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, mut receiver: mpsc::UnboundedReceiver, ) { let mut state = 
State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, + support_chunking, ); while let Some(cmd) = receiver.recv().await { match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -141,7 +154,7 @@ async fn queue_task( .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); } } } @@ -162,12 +175,14 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding window - window_size: Option, - /// Speculation amount speculate: u32, + /// Whether the model allow the prefill chunking + /// If it does, the last request in the batch will be split to exactly match the prefill + /// token budget + support_chunking: bool, + /// Paged Attention Block Allocation block_allocator: Option, } @@ -176,20 +191,28 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, + support_chunking, block_allocator, } } @@ -226,29 +249,33 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); - - let mut batch_requests = Vec::with_capacity(self.entries.len()); - let mut batch_entries = - IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + next_batch_span.follows_from(Span::current()); + let mut batch = Vec::with_capacity(self.entries.len()); let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; let mut max_blocks = 0; // Pop entries starting from the front of the queue - 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() { + 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -258,7 +285,7 @@ impl State { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation max_input_length = max_input_length.max(entry.request.input_length); - prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length; + prefill_tokens = (batch.len() + 1) as u32 
* max_input_length; decode_tokens += entry.request.stopping_parameters.max_new_tokens; let total_tokens = prefill_tokens + decode_tokens + self.speculate; @@ -273,32 +300,21 @@ impl State { None } Some(block_allocator) => { - prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), + // If users wants the prefill logprobs, we cannot reuse the cache. + // So no input_ids for the radix tree. + let input_ids = if entry.request.decoder_input_details { + None + } else { + entry.request.input_ids.clone() }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } let tokens = entry.request.input_length + entry.request.stopping_parameters.max_new_tokens + self.speculate - 1; + tracing::debug!("Allocating {tokens} with {input_ids:?}"); - match block_allocator.allocate(tokens).await { + let block_allocation = match block_allocator.allocate(tokens, input_ids).await { None => { // Entry is over budget // Add it back to the front @@ -306,16 +322,88 @@ impl State { self.entries.push_front((id, entry)); break 'entry_loop; } - Some(block_allocation) => { + Some(mut block_allocation) => { tracing::debug!("Allocation: {block_allocation:?}"); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) + + if block_allocation.prefix_len == entry.request.input_length { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. 
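The decrement that follows, together with the chunking branch a few lines further down, encodes the complete prefill plan for a prefix-cached entry, and it is easier to check in isolation. Below is a standalone sketch under the same assumptions; `plan_prefill` and `PrefillPlan` are hypothetical names, and the real code additionally breaks out of the batching loop once the budget is exhausted.

/// Outcome of planning the prefill for one queued entry (illustrative type).
#[derive(Debug, PartialEq)]
enum PrefillPlan {
    /// Schedule `postfix_len` uncached tokens now, optionally chunked down to `chunk_len`.
    Schedule {
        postfix_len: u32,
        chunk_len: Option<u32>,
    },
    /// Nothing fits in the remaining budget; the entry goes back to the front of the queue.
    Requeue,
}

fn plan_prefill(
    input_length: u32,
    mut prefix_len: u32,
    already_scheduled: u32,
    prefill_token_budget: u32,
    support_chunking: bool,
) -> PrefillPlan {
    // Never treat the whole input as cached: the forward pass needs at least one postfix token.
    if prefix_len == input_length {
        prefix_len -= 1;
    }
    let postfix_len = input_length - prefix_len;
    if already_scheduled + postfix_len <= prefill_token_budget {
        return PrefillPlan::Schedule { postfix_len, chunk_len: None };
    }
    if support_chunking {
        // Clamp the scheduled prefill to whatever budget remains.
        let chunk_len = prefill_token_budget.saturating_sub(already_scheduled);
        if chunk_len > 0 {
            return PrefillPlan::Schedule { postfix_len, chunk_len: Some(chunk_len) };
        }
    }
    PrefillPlan::Requeue
}

fn main() {
    // Fully cached 512-token prompt: one token is still computed so the forward pass can run.
    assert_eq!(
        plan_prefill(512, 512, 0, 4096, false),
        PrefillPlan::Schedule { postfix_len: 1, chunk_len: None }
    );
    // Over budget with chunking: only the remaining 1024 tokens are scheduled in this batch.
    assert_eq!(
        plan_prefill(4096, 0, 3072, 4096, true),
        PrefillPlan::Schedule { postfix_len: 4096, chunk_len: Some(1024) }
    );
    // Over budget without chunking: the entry is requeued instead.
    assert_eq!(plan_prefill(4096, 0, 3072, 4096, false), PrefillPlan::Requeue);
}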
+ block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let postfix_len = entry.request.input_length - block_allocation.prefix_len; + + if prefill_tokens + postfix_len > prefill_token_budget { + // Entry is over budget + if self.support_chunking { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens); + if chunk_len > 0 { + // Push this entry inside the batch + batch.push((id, entry, Some(block_allocation), Some(chunk_len))); + } else { + // We cannot prefill even one token for this entry + // Add it back to the queue + self.entries.push_front((id, entry)); + } + tracing::debug!( + "Matched budget: prefill_tokens={} == {prefill_token_budget}", + prefill_tokens + postfix_len + ); + break 'entry_loop; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.entries.push_front((id, entry)); + break 'entry_loop; } } + + prefill_tokens += postfix_len; + + Some(block_allocation) } }; + batch.push((id, entry, block_allocation, None)); + if Some(batch.len()) == max_size { + break; + } + } - tracing::debug!("Accepting entry"); + // Empty batch + if batch.is_empty() { + tracing::debug!("Filterered out all entries"); + return None; + } + + // XXX We haven't allocated yet, so we're allowed to ditch the results. + // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch.len() < min_size { + // Add back entries to the queue in the correct order + for (id, entry, _, _) in batch.into_iter().rev() { + self.entries.push_front((id, entry)); + } + return None; + } + } + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + for (id, mut entry, block_allocation, chunk_len) in batch { // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -324,11 +412,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -337,11 +426,26 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(Input { - chunks: entry.request.inputs.clone(), + input_chunks: Some(client::Input { + chunks: entry + .request + .inputs + .clone() + .into_iter() + .map(|c| client::InputChunk { + chunk: Some(match c { + Chunk::Text(text) => client::Chunk::Text(text), + Chunk::Image(image) => client::Chunk::Image(client::Image { + data: image.data, + mimetype: image.mimetype, + }), + }), + }) + .collect(), }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, + add_special_tokens: entry.request.add_special_tokens, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), @@ -351,38 +455,14 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + cache_len: prefix_len, adapter_id: 
entry.request.adapter_id.clone(), + chunk_len, }); // Set batch_time entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - // Check if max_size - if Some(batch_requests.len()) == max_size { - break; - } - } - - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - - // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch_requests.len() < min_size { - // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - self.entries.push_front((id, entry)); - } - - return None; - } } // Final batch size @@ -399,7 +479,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } @@ -459,6 +539,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -471,7 +553,9 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], - input_length: 0, + input_ids: Some(Arc::new(vec![])), + input_length: 1, + add_special_tokens: true, truncate: 0, decoder_input_details: false, parameters: ValidParameters { @@ -506,7 +590,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -522,7 +606,7 @@ mod tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -530,7 +614,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -562,7 +646,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -582,7 +666,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -615,14 +699,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -630,7 +714,7 
@@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -663,7 +747,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -679,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -704,7 +788,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(true, 1, false, None, 2, 16, false); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -723,7 +807,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16, false); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 0000000000000000000000000000000000000000..8a544891199228fb61249e37effe0ee94fc42271 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,876 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; +use std::hash::{Hash, Hasher}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +fn hash(slice: &[u32]) -> u64 { + assert!(!slice.is_empty()); + if slice.len() == 1 { + slice[0] as u64 + } else { + let mut s = std::hash::DefaultHasher::new(); + slice.hash(&mut s); + s.finish() + } +} + +pub struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, + + #[allow(dead_code)] + // This isn't used because the prefix need to match without the windowing + // mecanism. This at worst is overallocating, not necessarily being wrong. + window_size: Option, + + block_size: u32, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(block_size as usize), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + window_size, + block_size, + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. 
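+ // Until then, evict least-recently used leaves from the radix trie to
+ // top the free list back up before checking again below.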
+ tracing::debug!( + "Free blocks {} need {n_blocks_needed}", + self.free_blocks.len() + ); + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +// Allocator trait +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + node_id + } else { + self.cache_blocks.root_id() + }; + + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. + self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len() * self.block_size as usize; + let suffix_len = tokens - prefix_len as u32; + + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + + tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}"); + + match self.alloc_or_reclaim(suffix_blocks as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + tracing::debug!("Cannot allocate {:?}", self.cache_blocks); + tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens"); + tracing::debug!("Block size {}", self.block_size); + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let aligned = + (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; + if aligned > 0 { + let prefix_len = self + .cache_blocks + .insert( + &prefill_tokens[..aligned], + &blocks[..aligned / self.block_size as usize], + ) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. 
+ // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + if prefix_len > allocation.cached_prefix_len { + self.free_blocks.extend( + &blocks[allocation.cached_prefix_len / self.block_size as usize + ..prefix_len / self.block_size as usize], + ); + } + } + } + + // Free non-prefill blocks. + self.free_blocks + .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root nod. + root: DefaultKey, + + /// Leave node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increating counter to avoid the system + /// call that a real time lookup would require. + time: u64, + + /// All blocks need to be aligned with this + block_size: usize, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new(block_size: usize) -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + block_size, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if key.len() >= self.block_size { + let node_key = hash(&key[..self.block_size]); + if let Some(&child_id) = node.children.get(&node_key) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + } + + node_id + } + + /// Decrease the reference count of a node. 
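+ ///
+ /// Once the reference count drops to zero, the node becomes an eviction
+ /// candidate and is added to the recency-ordered set of leaves.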
+ pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could be evicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + tracing::debug!("Evicting in search of {n_blocks}"); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks.saturating_sub(evicted.len()); + tracing::debug!("Evicting node {node_id:?} "); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + assert_eq!( + node.ref_count, 0, + "Leaf must have refcount of 0, got {}", + node.ref_count + ); + + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + + let truncate_blocks = node.blocks.len() - blocks_needed; + let truncate_tokens = truncate_blocks * self.block_size; + node.key.truncate(truncate_tokens); + evicted.extend(node.blocks.split_off(truncate_blocks)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + let common = self.insert_(self.root, tokens, blocks)?; + Ok(common) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. 
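+ // The key must cover exactly `blocks.len()` full blocks; the assertion
+ // below enforces this block alignment.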
+ + assert_eq!(tokens.len(), blocks.len() * self.block_size); + + let node_key = hash(&tokens[..self.block_size]); + if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len / self.block_size..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let prefix_blocks = prefix_len / self.block_size; + let mut parent_blocks = node.blocks.split_off(prefix_blocks); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = hash(&node.key[..self.block_size]); + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = hash(&key[..self.block_size]); + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(hash, child_id).is_none() { + // Only increase reference count if child does not replace another child. 
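+ // A node with children therefore always keeps a non-zero refcount and
+ // can never end up in the eviction candidate set.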
+ self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + assert!( + node.children.is_empty(), + "Tried to remove a node with {} children", + node.children.len() + ); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + + let node_key = hash(&node.key[..self.block_size]); + parent.children.remove(&node_key); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. + if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. 
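+ // The shared length is rounded down to a whole number of blocks, since
+ // only fully cached blocks can be reused.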
+ assert!(full > 0, "Prefixes must at least share 1 token"); + (full / block_size) * block_size +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.blocks, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. 
+ assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = RadixTrie::new(1); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. + assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. 
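+ // Only the first block ([0, 1]) is shared, so the reported prefix length is 2.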
+ assert_eq!( + trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = RadixTrie::new(1); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. + assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 756460e0a3caa6bf3bdee03085a0a36a39189482..23885da29f91e79fbb535750480eb00a0065c980 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -16,16 +16,15 @@ path = "src/main.rs" [dependencies] average = "0.14" clap = { version = "4.4.5", features = ["derive", "env"] } -crossterm = "0.27" float-ord = "0.3.2" serde = {version = "1.0.188", features = ["derive"]} serde_json = "1.0" tabled = "0.14.0" -text-generation-client = { path = "../router/client" } +text-generation-client = { path = "../backends/client" } thiserror = "1.0.48" tokenizers = { workspace = true } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] } -tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]} +ratatui = "0.28.1" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } hf-hub = { workspace = true } diff --git a/benchmark/README.md b/benchmark/README.md index 17a02a30c641d80f8fd836c0cdd62cdba9df5d46..f4e0cb168c562a65ebefdd7ab9371d2c81b210ee 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -7,7 +7,7 @@ A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha) -and powered by [tui](https://github.com/tui-rs-revival/ratatui). 
+and powered by [Ratatui](https://github.com/ratatui/ratatui). ## Install diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs index a0a9313a198fc052471aeb2d4eda673574d66b8a..7e3aeaf9ecb9dd44b79d0fa3d729d5d024a626b9 100644 --- a/benchmark/src/app.rs +++ b/benchmark/src/app.rs @@ -1,16 +1,15 @@ /// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs use crate::generation::{Decode, Message, Prefill}; -use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use text_generation_client::ClientError; -use tokio::sync::mpsc; -use tui::backend::Backend; -use tui::layout::{Alignment, Constraint, Direction, Layout}; -use tui::style::{Color, Modifier, Style}; -use tui::text::{Line, Span}; -use tui::widgets::{ +use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use ratatui::layout::{Alignment, Constraint, Direction, Layout}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::{Line, Span}; +use ratatui::widgets::{ Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs, }; -use tui::{symbols, Frame}; +use ratatui::{symbols, Frame}; +use text_generation_client::ClientError; +use tokio::sync::mpsc; /// TUI powered App pub(crate) struct App { @@ -153,7 +152,7 @@ impl App { } /// Render frame - pub fn render(&mut self, f: &mut Frame<'_, B>) { + pub fn render(&mut self, f: &mut Frame) { let batch_progress = (self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0); let run_progress = @@ -172,7 +171,7 @@ impl App { ] .as_ref(), ) - .split(f.size()); + .split(f.area()); // Top row horizontal layout let top = Layout::default() @@ -239,7 +238,7 @@ impl App { f.render_widget(helper, row5[0]); // Batch tabs - let titles = self + let titles: Vec = self .data .batch_size .iter() diff --git a/benchmark/src/event.rs b/benchmark/src/event.rs index 07482aed8de507d0a79f59cea418f017f88fef9e..d3f10fb6fc6bf6a1531acfa12784c07fbcb5c13d 100644 --- a/benchmark/src/event.rs +++ b/benchmark/src/event.rs @@ -1,5 +1,5 @@ /// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs -use crossterm::event; +use ratatui::crossterm::event; use std::time::{Duration, Instant}; use tokio::sync::{broadcast, mpsc}; diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703fbf36cd7e8d718b4f66eba4aed5cfc8e..63fc780818ba6cc1c76b77f1f9f9793eaf810df2 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -148,6 +148,7 @@ async fn prefill( }), inputs: sequence.clone(), truncate: sequence_length, + add_special_tokens: true, parameters: Some(parameters.clone()), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, @@ -157,6 +158,8 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + cache_len: 0, + chunk_len: None, adapter_id: None, }) .collect(); @@ -171,7 +174,7 @@ async fn prefill( // Run prefill let start_time = Instant::now(); - let (_, decode_batch, _) = client.prefill(batch.clone()).await?; + let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?; // Get latency let latency = start_time.elapsed(); diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index c33d64e673ef31b39fab8af66061c7acfc2ad64a..bb4b6a77f9f8d50769997a23fba6755c74a725c1 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -6,13 +6,13 @@ mod utils; use crate::app::App; use crate::event::Event; -use crossterm::ExecutableCommand; +use 
ratatui::backend::CrosstermBackend; +use ratatui::crossterm::ExecutableCommand; +use ratatui::Terminal; use std::io; use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; -use tui::backend::CrosstermBackend; -use tui::Terminal; /// Run benchmarking app #[allow(clippy::too_many_arguments)] @@ -50,9 +50,9 @@ pub async fn run( }; // Initialize terminal properties - crossterm::terminal::enable_raw_mode()?; - io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?; - io::stdout().execute(crossterm::cursor::Hide)?; + ratatui::crossterm::terminal::enable_raw_mode()?; + io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?; + io::stdout().execute(ratatui::crossterm::cursor::Hide)?; // Initialize terminal let mut terminal = { @@ -128,9 +128,9 @@ pub async fn run( let _ = shutdown_guard_receiver.recv().await; // Revert terminal to original view - io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?; - crossterm::terminal::disable_raw_mode()?; - io::stdout().execute(crossterm::cursor::Show)?; + io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?; + ratatui::crossterm::terminal::disable_raw_mode()?; + io::stdout().execute(ratatui::crossterm::cursor::Show)?; let parameters_table = table::parameters_table( tokenizer_name, diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 2ee3d7c551a41169df1ce86a83ce826ad644f011..2e2e9a11c12860fed0d1f3ae1e9e645ac20ab573 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -178,6 +178,7 @@ fn main() -> Result<(), Box> { .clear_cache(None) .await .expect("Unable to clear cache"); + tracing::info!("Connected"); // Run app diff --git a/clients/python/README.md b/clients/python/README.md index bf37508e070dd4b5238a8886dd0f1652dd046bae..88239aa16cc65f4ce6cf3510eacabd8746af2184 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,3 +1,6 @@ +# Legacy warning ⚠️ +The inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`. 
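+
+For example, a minimal sketch using the `InferenceClient` (swap in your own TGI endpoint or a Hub model id):
+
+```python
+from huggingface_hub import InferenceClient
+
+client = InferenceClient("http://localhost:3000")  # assumed local TGI endpoint
+print(client.text_generation("What is Deep Learning?", max_new_tokens=20))
+```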
+ # Text Generation The Hugging Face Text Generation Python library provides a convenient way of interfacing with a diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 2925085b6a4d2b5a446ffce0f22d17d5f0ba434f..47ef9d717844a51b98241bfe597d9358f4c7e86a 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -27,3 +27,6 @@ asyncio_mode = "auto" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.isort] +profile = "black" diff --git a/clients/python/tests_native/conftest.py b/clients/python/tests_native/conftest.py deleted file mode 100644 index 954787d871ccb6a0129030398a1397fcfee17e2b..0000000000000000000000000000000000000000 --- a/clients/python/tests_native/conftest.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest - -from text_generation import __version__ -from huggingface_hub.utils import build_hf_headers - - -@pytest.fixture -def flan_t5_xxl(): - return "google/flan-t5-xxl" - - -@pytest.fixture -def llama_7b(): - return "meta-llama/Llama-2-7b-chat-hf" - - -@pytest.fixture -def fake_model(): - return "fake/model" - - -@pytest.fixture -def unsupported_model(): - return "gpt2" - - -@pytest.fixture -def base_url(): - return "https://api-inference.huggingface.co/models" - - -@pytest.fixture -def bloom_url(base_url, bloom_model): - return f"{base_url}/{bloom_model}" - - -@pytest.fixture -def flan_t5_xxl_url(base_url, flan_t5_xxl): - return f"{base_url}/{flan_t5_xxl}" - - -@pytest.fixture -def llama_7b_url(base_url, llama_7b): - # return f"{base_url}/{llama_7b}" - return "http://localhost:3001" - - -@pytest.fixture -def fake_url(base_url, fake_model): - return f"{base_url}/{fake_model}" - - -@pytest.fixture -def unsupported_url(base_url, unsupported_model): - return f"{base_url}/{unsupported_model}" - - -@pytest.fixture(scope="session") -def hf_headers(): - # return build_hf_headers( - # library_name="text-generation-tests", library_version=__version__ - # ) - header = {'content-type': 'application/json'} - return header diff --git a/clients/python/tests_native/test_client.py b/clients/python/tests_native/test_client.py deleted file mode 100644 index 9c6cb962bad3e1f7b7ff3d0aced855001ed98ce8..0000000000000000000000000000000000000000 --- a/clients/python/tests_native/test_client.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytest - -from text_generation import Client, AsyncClient -from text_generation.errors import NotFoundError, ValidationError -from text_generation.types import FinishReason, InputToken - - -def test_generate(llama_7b_url, hf_headers): - client = Client(llama_7b_url, hf_headers) - response = client.generate("test", max_new_tokens=1, decoder_input_details=True) - - assert response.generated_text == "_" - assert response.details.finish_reason == FinishReason.Length - assert response.details.generated_tokens == 1 - assert response.details.seed is None - assert len(response.details.prefill) == 2 - assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) - assert len(response.details.tokens) == 1 - assert response.details.tokens[0].id == 29918 - assert response.details.tokens[0].text == "_" - assert not response.details.tokens[0].special - - -def test_generate_best_of(llama_7b_url, hf_headers): - client = Client(llama_7b_url, hf_headers) - response = client.generate( - "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True - ) - - assert response.details.seed is not None - assert response.details.best_of_sequences is not None - assert 
len(response.details.best_of_sequences) == 1 - assert response.details.best_of_sequences[0].seed is not None - -def test_generate_validation_error(llama_7b_url, hf_headers): - client = Client(llama_7b_url, hf_headers) - with pytest.raises(ValidationError): - client.generate("test", max_new_tokens=10_000) - - -def test_generate_stream(llama_7b_url, hf_headers): - client = Client(llama_7b_url, hf_headers) - responses = [ - response for response in client.generate_stream("test", max_new_tokens=1) - ] - - assert len(responses) == 1 - response = responses[0] - - assert response.generated_text == "_" - assert response.details.finish_reason == FinishReason.Length - assert response.details.generated_tokens == 1 - assert response.details.seed is None - - -def test_generate_stream_validation_error(llama_7b_url, hf_headers): - client = Client(llama_7b_url, hf_headers) - with pytest.raises(ValidationError): - list(client.generate_stream("test", max_new_tokens=10_000)) - - -@pytest.mark.asyncio -async def test_generate_async(llama_7b_url, hf_headers): - client = AsyncClient(llama_7b_url, hf_headers) - response = await client.generate( - "test", max_new_tokens=1, decoder_input_details=True - ) - - assert response.generated_text == "_" - assert response.details.finish_reason == FinishReason.Length - assert response.details.generated_tokens == 1 - assert response.details.seed is None - assert len(response.details.prefill) == 2 - assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) - assert response.details.prefill[1] == InputToken( - id=1243, text="test", logprob=-10.9375 - ) - assert len(response.details.tokens) == 1 - assert response.details.tokens[0].id == 29918 - assert response.details.tokens[0].text == "_" - assert not response.details.tokens[0].special - - -@pytest.mark.asyncio -async def test_generate_async_best_of(llama_7b_url, hf_headers): - client = AsyncClient(llama_7b_url, hf_headers) - response = await client.generate( - "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True - ) - - assert response.details.seed is not None - assert response.details.best_of_sequences is not None - assert len(response.details.best_of_sequences) == 1 - assert response.details.best_of_sequences[0].seed is not None - - -@pytest.mark.asyncio -async def test_generate_async_validation_error(llama_7b_url, hf_headers): - client = AsyncClient(llama_7b_url, hf_headers) - with pytest.raises(ValidationError): - await client.generate("test", max_new_tokens=10_000) - - -@pytest.mark.asyncio -async def test_generate_stream_async(llama_7b_url, hf_headers): - client = AsyncClient(llama_7b_url, hf_headers) - responses = [ - response async for response in client.generate_stream("test", max_new_tokens=1) - ] - - assert len(responses) == 1 - response = responses[0] - - assert response.generated_text == "_" - assert response.details.finish_reason == FinishReason.Length - assert response.details.generated_tokens == 1 - assert response.details.seed is None - - -@pytest.mark.asyncio -async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers): - client = AsyncClient(llama_7b_url, hf_headers) - with pytest.raises(ValidationError): - async for _ in client.generate_stream("test", max_new_tokens=10_000): - pass diff --git a/clients/python/tests_native/test_errors.py b/clients/python/tests_native/test_errors.py deleted file mode 100644 index 8389ed313812a8e56bef21a5b4437d55911469f3..0000000000000000000000000000000000000000 --- a/clients/python/tests_native/test_errors.py 
+++ /dev/null @@ -1,64 +0,0 @@ -from text_generation.errors import ( - parse_error, - GenerationError, - IncompleteGenerationError, - OverloadedError, - ValidationError, - BadRequestError, - ShardNotReadyError, - ShardTimeoutError, - NotFoundError, - RateLimitExceededError, - UnknownError, -) - - -def test_generation_error(): - payload = {"error_type": "generation", "error": "test"} - assert isinstance(parse_error(400, payload), GenerationError) - - -def test_incomplete_generation_error(): - payload = {"error_type": "incomplete_generation", "error": "test"} - assert isinstance(parse_error(400, payload), IncompleteGenerationError) - - -def test_overloaded_error(): - payload = {"error_type": "overloaded", "error": "test"} - assert isinstance(parse_error(400, payload), OverloadedError) - - -def test_validation_error(): - payload = {"error_type": "validation", "error": "test"} - assert isinstance(parse_error(400, payload), ValidationError) - - -def test_bad_request_error(): - payload = {"error": "test"} - assert isinstance(parse_error(400, payload), BadRequestError) - - -def test_shard_not_ready_error(): - payload = {"error": "test"} - assert isinstance(parse_error(403, payload), ShardNotReadyError) - assert isinstance(parse_error(424, payload), ShardNotReadyError) - - -def test_shard_timeout_error(): - payload = {"error": "test"} - assert isinstance(parse_error(504, payload), ShardTimeoutError) - - -def test_not_found_error(): - payload = {"error": "test"} - assert isinstance(parse_error(404, payload), NotFoundError) - - -def test_rate_limit_exceeded_error(): - payload = {"error": "test"} - assert isinstance(parse_error(429, payload), RateLimitExceededError) - - -def test_unknown_error(): - payload = {"error": "test"} - assert isinstance(parse_error(500, payload), UnknownError) diff --git a/clients/python/tests_native/test_types.py b/clients/python/tests_native/test_types.py deleted file mode 100644 index 77689ade7c9f9f0ac0ca8f3772e0d71ce31b7b0a..0000000000000000000000000000000000000000 --- a/clients/python/tests_native/test_types.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest - -from text_generation.types import Parameters, Request -from text_generation.errors import ValidationError - - -def test_parameters_validation(): - # Test best_of - Parameters(best_of=1) - with pytest.raises(ValidationError): - Parameters(best_of=0) - with pytest.raises(ValidationError): - Parameters(best_of=-1) - Parameters(best_of=2, do_sample=True) - with pytest.raises(ValidationError): - Parameters(best_of=2) - with pytest.raises(ValidationError): - Parameters(best_of=2, seed=1) - - # Test repetition_penalty - Parameters(repetition_penalty=1) - with pytest.raises(ValidationError): - Parameters(repetition_penalty=0) - with pytest.raises(ValidationError): - Parameters(repetition_penalty=-1) - - # Test seed - Parameters(seed=1) - with pytest.raises(ValidationError): - Parameters(seed=-1) - - # Test temperature - Parameters(temperature=1) - with pytest.raises(ValidationError): - Parameters(temperature=0) - with pytest.raises(ValidationError): - Parameters(temperature=-1) - - # Test top_k - Parameters(top_k=1) - with pytest.raises(ValidationError): - Parameters(top_k=0) - with pytest.raises(ValidationError): - Parameters(top_k=-1) - - # Test top_p - Parameters(top_p=0.5) - with pytest.raises(ValidationError): - Parameters(top_p=0) - with pytest.raises(ValidationError): - Parameters(top_p=-1) - with pytest.raises(ValidationError): - Parameters(top_p=1) - - # Test truncate - Parameters(truncate=1) - with 
pytest.raises(ValidationError): - Parameters(truncate=0) - with pytest.raises(ValidationError): - Parameters(truncate=-1) - - # Test typical_p - Parameters(typical_p=0.5) - with pytest.raises(ValidationError): - Parameters(typical_p=0) - with pytest.raises(ValidationError): - Parameters(typical_p=-1) - with pytest.raises(ValidationError): - Parameters(typical_p=1) - - -def test_request_validation(): - Request(inputs="test") - - with pytest.raises(ValidationError): - Request(inputs="") - - Request(inputs="test", stream=True) - Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True)) - - with pytest.raises(ValidationError): - Request( - inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True - ) diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index d7a09c9ebf6eb01ec4463c44dc9a971236408b17..ca783dcdfa7bd8eeebf4c57c0aa3d0a7e22dfbd2 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -19,5 +19,15 @@ DEPRECATION_WARNING = ( "Please use the `InferenceClient` from the `huggingface_hub` package instead." ) -from text_generation.client import Client, AsyncClient -from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient +from text_generation.client import Client, AsyncClient # noqa E402 +from text_generation.inference_api import ( # noqa E402 + InferenceAPIClient, + InferenceAPIAsyncClient, +) + +__all__ = [ + "Client", + "AsyncClient", + "InferenceAPIClient", + "InferenceAPIAsyncClient", +] diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 129667472f7da83d55cecc2582c92924f4efcbb9..45301b6325664d2906b390c3fdb187d409b12766 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -757,7 +757,12 @@ class AsyncClient: continue payload = byte_payload.decode("utf-8") if payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + payload_data = ( + payload.lstrip("data:").rstrip("\n").removeprefix(" ") + ) + if payload_data == "[DONE]": + break + json_payload = json.loads(payload_data) try: response = ChatCompletionChunk(**json_payload) yield response diff --git a/clients/python/text_generation/inference_api.py b/clients/python/text_generation/inference_api.py index 93b0de8d42c81f2bf2b428239e71eda6f37b3c24..b3b98ed28bf1ac1b948c6b012f6dd3c1ed3c6aa7 100644 --- a/clients/python/text_generation/inference_api.py +++ b/clients/python/text_generation/inference_api.py @@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]: List[DeployedModel]: list of all currently deployed models """ resp = requests.get( - f"https://api-inference.huggingface.co/framework/text-generation-inference", + "https://api-inference.huggingface.co/framework/text-generation-inference", headers=headers, timeout=5, ) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index a56edaca75909cb52e0078830d7ee57f2ee4f384..1085075e43803bed472a082ce64759860dfc4626 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -28,11 +28,17 @@ class ToolCall(BaseModel): function: dict +class Chunk(BaseModel): + type: str + text: Optional[str] = None + image_url: Any = None + + class Message(BaseModel): # Role of the message sender role: str # Content of the message - content: Optional[str] = None + content: Optional[Union[str, 
List[Chunk]]] = None # Optional name of the message sender name: Optional[str] = None # Tool calls associated with the chat completion @@ -61,7 +67,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): @@ -168,7 +174,7 @@ class ChatCompletionComplete(BaseModel): # Log probabilities for the chat completion logprobs: Optional[Any] # Reason for completion - finish_reason: str + finish_reason: Optional[str] # Usage details of the chat completion usage: Optional[Any] = None @@ -191,6 +197,7 @@ class ChatCompletionChunk(BaseModel): model: str system_fingerprint: str choices: List[Choice] + usage: Optional[Any] = None class Parameters(BaseModel): diff --git a/docs/openapi.json b/docs/openapi.json index 03e2d4ff26287e1de734590c60cfe6231b1f625f..7eca49637ae3868962e95ea8f86ce999bc7c4331 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.1.1" + "version": "2.4.0" }, "paths": { "/": { @@ -316,6 +316,98 @@ } } }, + "/invocations": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Generate tokens from Sagemaker request", + "operationId": "sagemaker_compatibility", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Chat Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/SagemakerStreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error", + "error_type": "validation" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation", + "error_type": "generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded", + "error_type": "overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation", + "error_type": "incomplete_generation" + } + } + } + } + } + } + }, "/metrics": { "get": { "tags": [ @@ -492,12 +584,12 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Completion" + "$ref": "#/components/schemas/CompletionFinal" } }, "text/event-stream": { "schema": { - "$ref": "#/components/schemas/CompletionCompleteChunk" + "$ref": "#/components/schemas/Chunk" } } } @@ -556,6 +648,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": 
"#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -711,6 +834,14 @@ }, "system_fingerprint": { "type": "string" + }, + "usage": { + "allOf": [ + { + "$ref": "#/components/schemas/Usage" + } + ], + "nullable": true } } }, @@ -809,7 +940,6 @@ "ChatRequest": { "type": "object", "required": [ - "model", "messages" ], "properties": { @@ -820,6 +950,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { @@ -854,7 +991,8 @@ "model": { "type": "string", "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { "type": "integer", @@ -899,6 +1037,14 @@ "stream": { "type": "boolean" }, + "stream_options": { + "allOf": [ + { + "$ref": "#/components/schemas/StreamOptions" + } + ], + "nullable": true + }, "temperature": { "type": "number", "format": "float", @@ -909,7 +1055,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -917,7 +1063,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { @@ -1116,7 +1262,6 @@ "CompletionRequest": { "type": "object", "required": [ - "model", "prompt" ], "properties": { @@ -1138,7 +1283,8 @@ "model": { "type": "string", "description": "UNUSED\nID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { "$ref": "#/components/schemas/Prompt" @@ -1324,6 +1470,17 @@ } } }, + "FunctionName": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -1569,16 +1726,11 @@ "type": "object", "required": [ "model_id", - "model_dtype", - "model_device_type", "max_concurrent_requests", "max_best_of", "max_stop_sequences", "max_input_tokens", "max_total_tokens", - "waiting_served_ratio", - "max_batch_total_tokens", - "max_waiting_tokens", "validation_workers", "max_client_batch_size", "router", @@ -1590,18 +1742,6 @@ "example": "null", "nullable": true }, - "max_batch_size": { - "type": "integer", - "example": "null", - "nullable": true, - "minimum": 0 - }, - "max_batch_total_tokens": { - "type": "integer", - "format": "int32", - "example": "32000", - "minimum": 0 - }, "max_best_of": { "type": "integer", "example": "2", @@ -1633,19 +1773,6 @@ "example": "2048", "minimum": 0 }, - "max_waiting_tokens": { - "type": "integer", - "example": "20", - "minimum": 0 - }, - "model_device_type": { - "type": "string", - "example": "cuda" - }, - "model_dtype": { - "type": "string", - "example": "torch.float16" - }, "model_id": { "type": "string", "description": "Model info", @@ -1679,11 +1806,6 @@ "version": { "type": "string", "example": "0.5.0" - }, - "waiting_served_ratio": { - "type": "number", - "format": "float", - "example": "1.2" } } }, @@ -1708,6 +1830,101 @@ } } }, + "MessageChunk": { + "oneOf": [ + { + "type": "object", + "required": [ + "text", + "type" + ], + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "text" + ] + } + } + }, + { + "type": "object", + "required": [ + "image_url", + "type" + ], + "properties": { + "image_url": { + "$ref": "#/components/schemas/Url" + }, + "type": { + "type": "string", + "enum": [ + "image_url" + ] + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "MessageContent": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageChunk" + } + } + ] + }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, + "OutputMessage": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + } + ] + }, "PrefillToken": { "type": "object", "required": [ @@ -1740,6 +1957,45 @@ "type": "string" } }, + "SagemakerRequest": { + "oneOf": [ + { + "$ref": "#/components/schemas/CompatGenerateRequest" + }, + { + "$ref": "#/components/schemas/ChatRequest" + }, + { + "$ref": "#/components/schemas/CompletionRequest" + } + ] + }, + "SagemakerResponse": { + "oneOf": [ + { + "$ref": "#/components/schemas/GenerateResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletion" + }, + { + "$ref": "#/components/schemas/CompletionFinal" + } + ] + }, + "SagemakerStreamResponse": { + "oneOf": [ + { + "$ref": 
"#/components/schemas/StreamResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletionChunk" + }, + { + "$ref": "#/components/schemas/Chunk" + } + ] + }, "SimpleToken": { "type": "object", "required": [ @@ -1775,7 +2031,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1787,6 +2044,12 @@ "example": 1, "minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", @@ -1796,6 +2059,19 @@ } } }, + "StreamOptions": { + "type": "object", + "required": [ + "include_usage" + ], + "properties": { + "include_usage": { + "type": "boolean", + "description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.", + "example": "true" + } + } + }, "StreamResponse": { "type": "object", "required": [ @@ -1834,6 +2110,23 @@ } } }, + "TextMessage": { + "type": "object", + "required": [ + "role", + "content" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I" + }, + "role": { + "type": "string", + "example": "user" + } + } + }, "Token": { "type": "object", "required": [ @@ -1906,15 +2199,64 @@ } } }, + "ToolCallDelta": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "$ref": "#/components/schemas/DeltaToolCall" + } + } + }, + "ToolCallMessage": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { - "type": "object", - "default": null, - "nullable": true + "type": "string", + "description": "Means the model can pick between generating a message or calling one or more tools.", + "enum": [ + "auto" + ] }, { - "type": "string" + "type": "string", + "description": "Means the model will not call any tool and instead generates a message.", + "enum": [ + "none" + ] }, { "type": "object", @@ -1927,7 +2269,20 @@ } } } - ] + ], + "description": "Controls which (if any) tool is called by the model.", + "example": "auto" + }, + "Url": { + "type": "object", + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + } + } }, "Usage": { "type": "object", diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c9b4efd982b4a24273bbc4f51efcbe4c84ca3cf9..4876f7c586d4a9c24d19c4e2053c0a94119b9e90 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -3,6 +3,8 @@ title: Text Generation Inference - local: quicktour title: Quick Tour + - local: supported_models + title: Supported Models - local: installation_nvidia title: Using TGI with Nvidia GPUs - local: installation_amd @@ -11,14 +13,15 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - - local: supported_models - 
title: Supported Models and Hardware - - local: messages_api - title: Messages API + - local: architecture title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi @@ -29,8 +32,6 @@ title: Serving Private & Gated Models - local: basic_tutorials/using_cli title: Using TGI CLI - - local: basic_tutorials/launcher - title: All TGI CLI options - local: basic_tutorials/non_core_models title: Non-core Model Serving - local: basic_tutorials/safety @@ -44,6 +45,14 @@ - local: basic_tutorials/train_medusa title: Train Medusa title: Tutorials +- sections: + - local: reference/launcher + title: All TGI CLI options + - local: reference/metrics + title: Exported Metrics + - local: reference/api_reference + title: API Reference + title: Reference - sections: - local: conceptual/streaming title: Streaming @@ -60,9 +69,11 @@ - local: conceptual/speculation title: Speculation (Medusa, ngram) - local: conceptual/guidance - title: How Guidance Works (via outlines + title: How Guidance Works (via outlines) - local: conceptual/lora title: LoRA (Low-Rank Adaptation) + - local: conceptual/external + title: External Resources title: Conceptual Guides diff --git a/docs/source/architecture.md b/docs/source/architecture.md index a8418817ebc5156f46b667f0d7a5f053a8d9201b..6660630d0613e871fc31889aa42c69648c5c4a99 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -10,7 +10,7 @@ This diagram shows well there are these separate components: - **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server. - **The model server**, responsible of receiving the gRPC requests and to process the inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent. -- **The launcher** is a helper thar will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments. +- **The launcher** is a helper that will be able to launch one or several model servers (if model is sharded), and it launches the router with the compatible arguments. The router and the model server can be two different machines, they do not need to be deployed together. @@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). 
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md index 4829ec7c9aa3b145e2cdc27959e117ec3350745b..b07e7219d7cb72978a9dc25ca02fa2b5b6e0adc2 100644 --- a/docs/source/basic_tutorials/consuming_tgi.md +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -1,81 +1,125 @@ # Consuming Text Generation Inference -There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. +There are many ways to consume the Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens. + +For more information on the API, consult the OpenAPI documentation of `text-generation-inference`, available [here](https://huggingface.github.io/text-generation-inference). + +You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models. ## curl -After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: +After a successful server launch, you can query the model using the `/v1/chat/completions` route to get responses that are compliant with the OpenAI Chat Completion spec: ```bash -curl 127.0.0.1:8080/generate \ +curl localhost:8080/v1/chat/completions \ -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ -H 'Content-Type: application/json' ``` - -## Inference Client - -[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. -You can simply install `huggingface-hub` package with pip. +For non-chat use cases, you can also use the `/generate` and `/generate_stream` routes. ```bash -pip install huggingface-hub -``` - -Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model.
You can then call `text_generation()` to hit the endpoint through Python. - -```python -from huggingface_hub import InferenceClient - -client = InferenceClient(model="http://127.0.0.1:8080") -client.text_generation(prompt="Write a code for snake game") +curl 127.0.0.1:8080/generate \ + -X POST \ + -d '{ + "inputs":"What is Deep Learning?", + "parameters":{ + "max_new_tokens":20 + } +}' \ + -H 'Content-Type: application/json' ``` -You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: +## Python -```python -for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): - print(token) -``` +### Inference Client -Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. +[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface. -```python -output = client.text_generation(prompt="Meaning of life is", details=True) -print(output) +Install `huggingface_hub` package via pip. -# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..)) +```bash +pip install huggingface_hub ``` -You can see how to stream below. +You can now use `InferenceClient` the exact same way you would use `OpenAI` client in Python ```python -output = client.text_generation(prompt="Meaning of life is", stream=True, details=True) -print(next(iter(output))) +from huggingface_hub import InferenceClient -# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None) +client = InferenceClient( + base_url="http://localhost:8080/v1/", +) + +output = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, +) + +for chunk in output: + print(chunk.choices[0].delta.content) ``` -You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. 
You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) +You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility). +There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) -## ChatUI +### OpenAI Client -ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces. +You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI. -To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. +Install the OpenAI Python package via pip. +```bash +pip install openai ``` -{ -// rest of the model config here -"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] -} + +```python +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + base_url="http://localhost:8080/v1/", + api_key="-" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message) ``` -![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) +## UI -## Gradio +### Gradio Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. 
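Before moving on to the Gradio demo, the sketch below shows what the raw stream that all of these clients consume looks like on the wire. It mirrors the `data: ... [DONE]` handling added to the Python client in this diff and the new `stream_options.include_usage` field; it is only an illustration, and the URL, port, prompt, and timeout are assumptions rather than part of the change.

```python
import json

import requests

# Minimal sketch: read TGI's OpenAI-compatible SSE stream by hand.
# The endpoint, prompt, and timeout below are illustrative assumptions.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "stream": True,
        # With include_usage, a final chunk before [DONE] carries token counts.
        "stream_options": {"include_usage": True},
    },
    stream=True,
    timeout=60,
)

for raw_line in response.iter_lines():
    if not raw_line:
        continue
    payload = raw_line.decode("utf-8")
    if not payload.startswith("data:"):
        continue
    data = payload.removeprefix("data:").strip()
    if data == "[DONE]":  # sentinel that terminates the stream
        break
    chunk = json.loads(data)
    if chunk.get("usage"):
        print("\nusage:", chunk["usage"])
    for choice in chunk.get("choices", []):
        delta = choice.get("delta", {}).get("content")
        if delta:
            print(delta, end="", flush=True)
```

In practice the `huggingface_hub` and `openai` clients shown above do this parsing for you; the sketch only makes the wire format explicit.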
@@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference import gradio as gr from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") +client = InferenceClient(base_url="http://127.0.0.1:8080") def inference(message, history): partial_message = "" - for token in client.text_generation(message, max_new_tokens=20, stream=True): - partial_message += token + output = client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": message}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + partial_message += chunk.choices[0].delta.content yield partial_message gr.ChatInterface( inference, chatbot=gr.Chatbot(height=300), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), - description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + description="This is the demo for Gradio UI consuming TGI endpoint.", title="Gradio 🤝 TGI", examples=["Are tomatoes vegetables?"], retry_btn="Retry", @@ -110,20 +163,7 @@ gr.ChatInterface( ).queue().launch() ``` -The UI looks like this 👇 -[removed screenshot of the Gradio UI] -You can try the demo directly here 👇 +You can check out the UI and try the demo directly here 👇
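Separately, the `Message` and `MessageChunk` changes earlier in this diff allow `content` to be a list of typed chunks (text and image URLs) instead of a plain string. Below is a minimal sketch of that shape, using the same `huggingface_hub` client as the docs above; the image URL and prompt are made-up placeholders, and the served model must actually accept image inputs.

```python
from huggingface_hub import InferenceClient

# Sketch of chunked (multimodal) message content; values are placeholders.
client = InferenceClient(base_url="http://localhost:8080/v1/")

output = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is in this picture?"},
            ],
        }
    ],
    max_tokens=64,
)
print(output.choices[0].message.content)
```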