Commit 3cf6368c authored by OlivierDehaene

feat(server): Support all AutoModelForCausalLM on a best effort basis

parent 09674e6d
...@@ -28,9 +28,9 @@ dependencies = [ ...@@ -28,9 +28,9 @@ dependencies = [
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.65" version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
[[package]] [[package]]
name = "async-stream" name = "async-stream"
...@@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" ...@@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.5.16" version = "0.5.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043" checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
...@@ -114,9 +114,9 @@ dependencies = [ ...@@ -114,9 +114,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.2.8" version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b" checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes", "bytes",
...@@ -130,9 +130,9 @@ dependencies = [ ...@@ -130,9 +130,9 @@ dependencies = [
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.13.0" version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "bitflags" name = "bitflags"
...@@ -149,21 +149,6 @@ dependencies = [ ...@@ -149,21 +149,6 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "bloom-inference-client"
version = "0.1.0"
dependencies = [
"futures",
"prost",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tower",
"tracing",
"tracing-error",
]
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.11.1" version = "3.11.1"
...@@ -255,9 +240,9 @@ dependencies = [ ...@@ -255,9 +240,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.0.17" version = "4.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267" checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
dependencies = [ dependencies = [
"atty", "atty",
"bitflags", "bitflags",
...@@ -270,9 +255,9 @@ dependencies = [ ...@@ -270,9 +255,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.0.13" version = "4.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad" checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
dependencies = [ dependencies = [
"heck 0.4.0", "heck 0.4.0",
"proc-macro-error", "proc-macro-error",
...@@ -532,14 +517,14 @@ dependencies = [ ...@@ -532,14 +517,14 @@ dependencies = [
[[package]] [[package]]
name = "filetime" name = "filetime"
version = "0.2.17" version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c" checksum = "4b9663d381d07ae25dc88dbdf27df458faa83a9b25336bcac83d5e452b5fc9d3"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall", "redox_syscall",
"windows-sys 0.36.1", "windows-sys 0.42.0",
] ]
[[package]] [[package]]
...@@ -600,9 +585,9 @@ dependencies = [ ...@@ -600,9 +585,9 @@ dependencies = [
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
...@@ -615,9 +600,9 @@ dependencies = [ ...@@ -615,9 +600,9 @@ dependencies = [
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
...@@ -625,15 +610,15 @@ dependencies = [ ...@@ -625,15 +610,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-core" name = "futures-core"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac"
[[package]] [[package]]
name = "futures-executor" name = "futures-executor"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-task", "futures-task",
...@@ -642,15 +627,15 @@ dependencies = [ ...@@ -642,15 +627,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-io" name = "futures-io"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb"
[[package]] [[package]]
name = "futures-macro" name = "futures-macro"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -659,21 +644,21 @@ dependencies = [ ...@@ -659,21 +644,21 @@ dependencies = [
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9"
[[package]] [[package]]
name = "futures-task" name = "futures-task"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
...@@ -699,9 +684,9 @@ dependencies = [ ...@@ -699,9 +684,9 @@ dependencies = [
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.7" version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
...@@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" ...@@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.14" version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be" checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
...@@ -967,9 +952,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" ...@@ -967,9 +952,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.135" version = "0.2.137"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c" checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
[[package]] [[package]]
name = "lock_api" name = "lock_api"
...@@ -992,9 +977,9 @@ dependencies = [ ...@@ -992,9 +977,9 @@ dependencies = [
[[package]] [[package]]
name = "macro_rules_attribute" name = "macro_rules_attribute"
version = "0.1.2" version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "258c86475e1616d6f2d8f5227cfaabd3dae1f6d5388b9597df8a199d4497aba7" checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
dependencies = [ dependencies = [
"macro_rules_attribute-proc_macro", "macro_rules_attribute-proc_macro",
"paste", "paste",
...@@ -1002,9 +987,9 @@ dependencies = [ ...@@ -1002,9 +987,9 @@ dependencies = [
[[package]] [[package]]
name = "macro_rules_attribute-proc_macro" name = "macro_rules_attribute-proc_macro"
version = "0.1.2" version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea" checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]] [[package]]
name = "matchit" name = "matchit"
...@@ -1050,14 +1035,14 @@ dependencies = [ ...@@ -1050,14 +1035,14 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.36.1", "windows-sys 0.42.0",
] ]
[[package]] [[package]]
...@@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" ...@@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.76" version = "0.9.77"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce" checksum = "b03b84c3b2d099b81f0953422b4d4ad58761589d0229b5506356afca05a3670a"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"cc", "cc",
...@@ -1213,9 +1198,9 @@ dependencies = [ ...@@ -1213,9 +1198,9 @@ dependencies = [
[[package]] [[package]]
name = "os_str_bytes" name = "os_str_bytes"
version = "6.3.0" version = "6.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" checksum = "3baf96e39c5359d2eb0dd6ccb42c62b91d9678aa68160d261b9e0ccbf9e9dea9"
[[package]] [[package]]
name = "overload" name = "overload"
...@@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" ...@@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]] [[package]]
name = "pkg-config" name = "pkg-config"
version = "0.3.25" version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
...@@ -1602,18 +1587,18 @@ dependencies = [ ...@@ -1602,18 +1587,18 @@ dependencies = [
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.145" version = "1.0.147"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.145" version = "1.0.147"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -1622,9 +1607,9 @@ dependencies = [ ...@@ -1622,9 +1607,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.86" version = "1.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
...@@ -1739,9 +1724,9 @@ dependencies = [ ...@@ -1739,9 +1724,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.102" version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -1798,11 +1783,26 @@ dependencies = [ ...@@ -1798,11 +1783,26 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "text-generation-client"
version = "0.1.0"
dependencies = [
"futures",
"prost",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tower",
"tracing",
"tracing-error",
]
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap 4.0.17", "clap 4.0.18",
"ctrlc", "ctrlc",
"subprocess", "subprocess",
"tracing", "tracing",
...@@ -1814,12 +1814,12 @@ name = "text-generation-router" ...@@ -1814,12 +1814,12 @@ name = "text-generation-router"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"bloom-inference-client", "clap 4.0.18",
"clap 4.0.17",
"futures", "futures",
"parking_lot", "parking_lot",
"serde", "serde",
"serde_json", "serde_json",
"text-generation-client",
"thiserror", "thiserror",
"tokenizers", "tokenizers",
"tokio", "tokio",
......
@@ -66,7 +66,7 @@ COPY proto proto
 COPY server server
 RUN cd server && \
     make gen-server && \
-    /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
+    /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir

 # Install router
 COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
......
@@ -22,7 +22,7 @@ run-bloom-560m-quantize:
     text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize

 download-bloom:
-    bloom-inference-server download-weights bigscience/bloom
+    text-generation-server download-weights bigscience/bloom

 run-bloom:
     text-generation-launcher --model-name bigscience/bloom --num-shard 8
......
@@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB

-## Supported models
+## Officially supported models

 - BLOOM
 - BLOOM-560m

+Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
+
 ## Load Tests for BLOOM

 See `k6/load_test.js`
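For reference, a minimal sketch of what that best-effort path amounts to on the Python side, assuming a GPU machine with `transformers` and `accelerate` installed; the model id below is only a placeholder, not something this repository ships:

```python
# Minimal sketch of the best-effort loading recipe quoted above.
# "gpt2" is a placeholder model id; any AutoModelForCausalLM-compatible
# checkpoint can be substituted.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half-precision weights, as in the recipe above
    device_map="auto",          # let accelerate place layers on the available devices
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```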
......
 $schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
-name: bloom
+name: bloom-safetensors
 version: 1
-path: ./bloom
+path: ./bloom-safetensors
 type: custom_model
@@ -256,7 +256,7 @@ fn shard_manager(
     // Process args
     let mut shard_argv = vec![
-        "bloom-inference-server".to_string(),
+        "text-generation-server".to_string(),
         "serve".to_string(),
         model_name,
         "--uds-path".to_string(),
@@ -311,7 +311,7 @@ fn shard_manager(
         Err(err) => {
             if let PopenError::IoError(ref err) = err {
                 if err.kind() == io::ErrorKind::NotFound {
-                    tracing::error!("bloom-inference-server not found in PATH");
+                    tracing::error!("text-generation-server not found in PATH");
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
......
@@ -14,7 +14,7 @@ path = "src/main.rs"

 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
-bloom-inference-client = { path = "client" }
+text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 parking_lot = "0.12.1"
......
 [package]
-name = "bloom-inference-client"
+name = "text-generation-client"
 version = "0.1.0"
 edition = "2021"
......
@@ -3,9 +3,9 @@ use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
-use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;
......
 use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
-use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
......
-/// Text Generation Inference webserver entrypoint
-use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+/// Text Generation Inference webserver entrypoint
+use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -19,7 +19,7 @@ struct Args {
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
+    #[clap(default_value = "/tmp/text-generation-0", long, env)]
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
......
@@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
-use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
+use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
......
 # Byte-compiled / optimized / DLL files
 __pycache__/
-bloom_inference/__pycache__/
-bloom_inference/pb/__pycache__/
+text_generation/__pycache__/
+text_generation/pb/__pycache__/
 *.py[cod]
 *$py.class
......
 gen-server:
     # Compile protos
     pip install grpcio-tools==1.49.1 --no-cache-dir
-    mkdir bloom_inference/pb || true
-    python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
-    find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-    touch bloom_inference/pb/__init__.py
+    mkdir text_generation/pb || true
+    python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
+    find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+    touch text_generation/pb/__init__.py

 install-transformers:
     # Install specific version of transformers
@@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors
     pip install -e . --no-cache-dir

 run-dev:
-    python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
+    python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
 [tool.poetry]
-name = "bloom-inference"
+name = "text-generation"
 version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
-bloom-inference-server = 'bloom_inference.cli:app'
+text-generation-server = 'text_generation.cli:app'

 [tool.poetry.dependencies]
 python = "^3.9"
@@ -17,6 +17,9 @@ accelerate = "^0.12.0"
 joblib = "^1.2.0"
 bitsandbytes = "^0.35.1"

+[tool.poetry.extras]
+bnb = ["bitsandbytes"]
+
 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"
......
-from bloom_inference.model import Batch
 from typing import Dict, Optional
+from text_generation.models.types import Batch


 class Cache:
     def __init__(self):
......
@@ -3,7 +3,7 @@ import typer

 from pathlib import Path

-from bloom_inference import server, utils
+from text_generation import server, utils

 app = typer.Typer()
@@ -13,7 +13,7 @@ def serve(
     model_name: str,
     sharded: bool = False,
     quantize: bool = False,
-    uds_path: Path = "/tmp/bloom-inference",
+    uds_path: Path = "/tmp/text-generation",
 ):
     if sharded:
         assert (
@@ -35,8 +35,9 @@ def serve(
 @app.command()
 def download_weights(
     model_name: str,
+    extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name)
+    utils.download_weights(model_name, extension)


 if __name__ == "__main__":
......
from text_generation.models.model import Model
from text_generation.models.bloom import BLOOMSharded

__all__ = ["Model", "BLOOMSharded"]


def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
    if model_name.startswith("bigscience/bloom"):
        if sharded:
            return BLOOMSharded(model_name, quantize)
        else:
            if quantize:
                raise ValueError("quantization is not supported for non-sharded BLOOM")
            return Model(model_name)
    else:
        if sharded:
            raise ValueError("sharded is only supported for BLOOM")
        if quantize:
            raise ValueError("Quantization is only supported for BLOOM models")
        return Model(model_name)
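The dispatcher above is the crux of the change: `bigscience/bloom*` names keep the dedicated (optionally sharded and quantized) BLOOM path, while every other name falls through to the generic `Model` wrapper, i.e. the best-effort `AutoModelForCausalLM` route from the README. A purely illustrative usage sketch, assuming the `text_generation` package from this commit is installed and using an example model id:

```python
# Illustrative only: how a caller would pick a model implementation.
from text_generation.models import get_model

# Non-BLOOM names fall through to the generic Model wrapper
# (the best-effort AutoModelForCausalLM path).
model = get_model("gpt2", sharded=False, quantize=False)

# Sharding and quantization stay BLOOM-only:
# get_model("gpt2", sharded=True, quantize=False) raises ValueError.
```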
 import torch
 import torch.distributed

-from dataclasses import dataclass
-from typing import List, Tuple, Optional, Dict
+from typing import List, Optional

 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -13,10 +12,8 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )

-from bloom_inference.pb import generate_pb2
-from bloom_inference.utils import (
-    StoppingCriteria,
-    NextTokenChooser,
+from text_generation.models import Model
+from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
     download_weights,
@@ -32,359 +29,9 @@ except Exception as e:

 torch.manual_seed(0)

@dataclass
class Batch:
batch_id: int
requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
size: int
max_sequence_length: int
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
size=self.size,
max_sequence_length=self.max_sequence_length,
)
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Batch":
inputs = []
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
top_k=r.parameters.top_k,
top_p=r.parameters.top_p,
do_sample=r.parameters.do_sample,
)
)
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=pb.max_sequence_length,
)
@classmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Slicing end index for this batch
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
start_index:end_index, -batch.max_sequence_length :
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
for j, past in enumerate(batch.input_ids["past_key_values"]):
past_keys = past[0]
past_values = past[1]
_, head_dim, padded_sequence_length = past_keys.shape
# Reshape the tensors to make slicing easier
past_keys = past_keys.view(
batch.size, -1, head_dim, padded_sequence_length
)
past_values = past_values.view(
batch.size, -1, padded_sequence_length, head_dim
)
num_heads = past_keys.shape[1]
# Initialize tensors
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
padded_past_keys = torch.zeros(
(
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
),
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
(
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
),
dtype=past_values.dtype,
device=past_values.device,
)
input_ids["past_key_values"].append(
[padded_past_keys, padded_past_values]
)
# We slice the past keys and values to remove the padding from previous batches
input_ids["past_key_values"][j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
input_ids["past_key_values"][j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
# If we are on the last batch, we need to reshape the tensors
if (i + 1) == len(batches):
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
j
][0].view(total_batch_size * num_heads, head_dim, -1)
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
j
][1].view(total_batch_size * num_heads, -1, head_dim)
start_index += batch.size
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
)
@dataclass
class GeneratedText:
request: generate_pb2.Request
output: str
def to_pb(self) -> generate_pb2.GeneratedText:
return generate_pb2.GeneratedText(request=self.request, output=self.output)
class BLOOM:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
dtype = torch.bfloat16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = (
AutoModelForCausalLM.from_pretrained(model_name)
.eval()
.to(self.device)
.to(dtype)
)
self.num_heads = self.model.base_model.num_heads
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
# Model Forward
return self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
def generate_token(
self, batch: Batch
) -> Tuple[List[GeneratedText], Optional[Batch]]:
with torch.inference_mode():
outputs = self.forward(**batch.input_ids)
# List of indices to cache
next_batch_keep_indices = []
next_batch_past_keep_indices = []
# New input_ids for next forward
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0
next_batch_max_sequence_length = 0
# Finished requests
generated_texts: List[GeneratedText] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
outputs.logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
# Append next token to all tokens
all_tokens = torch.cat([all_tokens, next_token])
# Evaluate stopping criteria
if stopping_criteria(all_tokens):
# Decode all tokens
output = self.tokenizer.decode(
all_tokens.squeeze(-1), skip_special_tokens=True
)
# Add to the list of finished generations with the original request
generated_texts.append(GeneratedText(request, output))
# add to the next batch
else:
next_batch_keep_indices.append(i)
# past_key_values is of shape [batch_size * num_heads, ...]
# so we need to take into account the `num_heads` stride here
next_batch_past_keep_indices.extend(
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
)
next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
next_batch_input_ids["past_key_values"] = [
(
keys[next_batch_past_keep_indices],
values[next_batch_past_keep_indices],
)
for keys, values in outputs["past_key_values"]
]
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
[
next_batch_input_ids["attention_mask"],
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
)
next_batch = Batch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
)
return generated_texts, next_batch
-class BLOOMSharded(BLOOM):
+class BLOOMSharded(Model):
     def __init__(self, model_name: str, quantize: bool = False):
-        super(BLOOM, self).__init__()
+        super(Model, self).__init__()
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
@@ -411,10 +58,10 @@ class BLOOMSharded(BLOOM):

         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name)
+            download_weights(model_name, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name)
+        filenames = weight_files(model_name, extension=".safetensors")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
@@ -500,7 +147,9 @@ class BLOOMSharded(BLOOM):
             if quantize:
                 if not HAS_BITS_AND_BYTES:
                     raise ImportError(
-                        "bitsandbytes is not available on your machine"
+                        "bitsandbytes is not available on your machine either because it is not installed "
+                        "or you don't have a GPU.\n"
+                        "You can install it with `pip install bitsandbytes`."
                     )

             if (
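The friendlier error message above refers to an optional-dependency guard around the bitsandbytes import, the `HAS_BITS_AND_BYTES` flag checked in the hunk. A minimal sketch of that guard pattern; the exact symbol the server imports may differ, only the idiom is shown:

```python
# Optional-dependency guard mirroring the HAS_BITS_AND_BYTES flag used above.
try:
    import bitsandbytes  # noqa: F401

    HAS_BITS_AND_BYTES = True
except Exception:
    HAS_BITS_AND_BYTES = False
```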
......