OpenDAS / text-generation-inference / Commits

Commit 3cf6368c, authored Oct 28, 2022 by OlivierDehaene

feat(server): Support all AutoModelForCausalLM on a best effort basis

Parent: 09674e6d
Changes: 25 files. Showing 20 changed files with 366 additions and 106 deletions (+366, -106); the remaining files are on the second page of the diff.
Changed files:

 Cargo.lock                                   +78  -78
 Dockerfile                                    +1   -1
 Makefile                                      +1   -1
 README.md                                     +3   -1
 aml/model.yaml                                +2   -2
 launcher/src/main.rs                          +2   -2
 router/Cargo.toml                             +1   -1
 router/client/Cargo.toml                      +1   -1
 router/src/batcher.rs                         +1   -1
 router/src/db.rs                              +1   -1
 router/src/main.rs                            +3   -3
 router/src/server.rs                          +1   -1
 server/.gitignore                             +2   -2
 server/Makefile                               +5   -5
 server/pyproject.toml                         +5   -2
 server/text_generation/__init__.py            +0   -0
 server/text_generation/cache.py               +2   -1
 server/text_generation/cli.py                 +4   -3
 server/text_generation/models/__init__.py    +22   -0
 server/text_generation/models/bloom.py      +231   -0
Cargo.lock

Dependency version bumps (each with its crates.io registry checksum updated accordingly):

 anyhow                            1.0.65  -> 1.0.66
 axum                              0.5.16  -> 0.5.17
 axum-core                         0.2.8   -> 0.2.9
 base64                            0.13.0  -> 0.13.1
 clap                              4.0.17  -> 4.0.18
 clap_derive                       4.0.13  -> 4.0.18
 filetime                          0.2.17  -> 0.2.18   (windows-sys 0.36.1 -> 0.42.0)
 futures                           0.3.24  -> 0.3.25
 futures-channel                   0.3.24  -> 0.3.25
 futures-core                      0.3.24  -> 0.3.25
 futures-executor                  0.3.24  -> 0.3.25
 futures-io                        0.3.24  -> 0.3.25
 futures-macro                     0.3.24  -> 0.3.25
 futures-sink                      0.3.24  -> 0.3.25
 futures-task                      0.3.24  -> 0.3.25
 futures-util                      0.3.24  -> 0.3.25
 getrandom                         0.2.7   -> 0.2.8
 h2                                0.3.14  -> 0.3.15
 libc                              0.2.135 -> 0.2.137
 macro_rules_attribute             0.1.2   -> 0.1.3
 macro_rules_attribute-proc_macro  0.1.2   -> 0.1.3
 mio                               0.8.4   -> 0.8.5    (windows-sys 0.36.1 -> 0.42.0)
 openssl-sys                       0.9.76  -> 0.9.77
 os_str_bytes                      6.3.0   -> 6.3.1
 pkg-config                        0.3.25  -> 0.3.26
 serde                             1.0.145 -> 1.0.147
 serde_derive                      1.0.145 -> 1.0.147
 serde_json                        1.0.86  -> 1.0.87
 syn                               1.0.102 -> 1.0.103

The local "bloom-inference-client" package entry (version 0.1.0; dependencies: futures, prost, thiserror, tokio, tonic, tonic-build, tower, tracing, tracing-error) is removed and re-added as "text-generation-client" with the same version and dependency list. "text-generation-launcher" now depends on clap 4.0.18 instead of 4.0.17, and "text-generation-router" replaces its "bloom-inference-client" dependency with "text-generation-client" and moves from clap 4.0.17 to 4.0.18.
Dockerfile

@@ -66,7 +66,7 @@ COPY proto proto
 COPY server server
 RUN cd server && \
     make gen-server && \
-    /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
+    /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir

 # Install router
 COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
...
Makefile

@@ -22,7 +22,7 @@ run-bloom-560m-quantize:
 	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize

 download-bloom:
-	bloom-inference-server download-weights bigscience/bloom
+	text-generation-server download-weights bigscience/bloom

 run-bloom:
 	text-generation-launcher --model-name bigscience/bloom --num-shard 8
...
README.md

@@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB

-## Supported models
+## Officially supported models

 - BLOOM
 - BLOOM-560m

+Other models are supported on a best-effort basis using
+`AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
+
 ## Load Tests for BLOOM

 See `k6/load_test.js`
...
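The new README line is the heart of this commit: models without a dedicated implementation are loaded generically through transformers. A minimal sketch of that best-effort loading path, run outside the server (the model name is only an example; the server's exact code is not shown on this page):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # any causal LM repo id; purely illustrative
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # half precision, as in the README
        device_map="auto",          # let accelerate place the weights
    )

    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))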
aml/model.yaml

 $schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
-name: bloom
+name: bloom-safetensors
 version: 1
-path: ./bloom
+path: ./bloom-safetensors
 type: custom_model
launcher/src/main.rs

@@ -256,7 +256,7 @@ fn shard_manager(
     // Process args
     let mut shard_argv = vec![
-        "bloom-inference-server".to_string(),
+        "text-generation-server".to_string(),
         "serve".to_string(),
         model_name,
         "--uds-path".to_string(),
...
@@ -311,7 +311,7 @@ fn shard_manager(
         Err(err) => {
             if let PopenError::IoError(ref err) = err {
                 if err.kind() == io::ErrorKind::NotFound {
-                    tracing::error!("bloom-inference-server not found in PATH");
+                    tracing::error!("text-generation-server not found in PATH");
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
...
router/Cargo.toml

@@ -14,7 +14,7 @@ path = "src/main.rs"
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
-bloom-inference-client = { path = "client" }
+text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 parking_lot = "0.12.1"
...
router/client/Cargo.toml

 [package]
-name = "bloom-inference-client"
+name = "text-generation-client"
 version = "0.1.0"
 edition = "2021"
...
router/src/batcher.rs

@@ -3,9 +3,9 @@ use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
-use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;
...
router/src/db.rs

 use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
-use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
...
router/src/main.rs

-/// Text Generation Inference webserver entrypoint
-use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+/// Text Generation Inference webserver entrypoint
+use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
...
@@ -19,7 +19,7 @@ struct Args {
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
+    #[clap(default_value = "/tmp/text-generation-0", long, env)]
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
...
router/src/server.rs

@@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
-use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
+use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
...
server/.gitignore

 # Byte-compiled / optimized / DLL files
 __pycache__/
-bloom_inference/__pycache__/
-bloom_inference/pb/__pycache__/
+text_generation/__pycache__/
+text_generation/pb/__pycache__/
 *.py[cod]
 *$py.class
...
server/Makefile

 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.49.1 --no-cache-dir
-	mkdir bloom_inference/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
-	find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch bloom_inference/pb/__init__.py
+	mkdir text_generation/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
+	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation/pb/__init__.py

 install-transformers:
 	# Install specific version of transformers
...
@@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors
 	pip install -e . --no-cache-dir

 run-dev:
-	python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
+	python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
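The gen-server recipe compiles ../proto/generate.proto into the renamed package and rewrites the generated imports to be relative, so the stubs can be used as a normal subpackage. A small hedged illustration of what that enables (module names follow the usual grpc_tools output for generate.proto; the Batch fields are the ones used elsewhere in this diff):

    # After `make gen-server`, the generated stubs live under text_generation/pb.
    from text_generation.pb import generate_pb2, generate_pb2_grpc

    # Build a protobuf Batch message like the ones the server exchanges with the router.
    batch = generate_pb2.Batch(id=0, requests=[], size=0, max_sequence_length=0)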
server/pyproject.toml

 [tool.poetry]
-name = "bloom-inference"
+name = "text-generation"
 version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
-bloom-inference-server = 'bloom_inference.cli:app'
+text-generation-server = 'text_generation.cli:app'

 [tool.poetry.dependencies]
 python = "^3.9"
...
@@ -17,6 +17,9 @@ accelerate = "^0.12.0"
 joblib = "^1.2.0"
 bitsandbytes = "^0.35.1"

+[tool.poetry.extras]
+bnb = ["bitsandbytes"]
+
 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"
...
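The new bnb extra names the bitsandbytes dependency behind the --quantize path; the Dockerfile change above opts into it by installing the server as `pip install ".[bnb]"` rather than `pip install .`.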
server/bloom_inference/__init__.py → server/text_generation/__init__.py

File moved.
server/bloom_inference/cache.py → server/text_generation/cache.py

-from bloom_inference.model import Batch
 from typing import Dict, Optional
+
+from text_generation.models.types import Batch

 class Cache:
     def __init__(self):
...
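The cache keeps in-flight batches between decoding steps, keyed by batch id. A minimal sketch of what such a wrapper can look like (only the import above comes from the diff; the method names are assumptions):

    from typing import Dict, Optional

    from text_generation.models.types import Batch


    class Cache:
        def __init__(self):
            # batch_id -> Batch still being decoded
            self.cache: Dict[int, Batch] = {}

        def set(self, entry: Optional[Batch]):
            if entry is not None:
                self.cache[entry.batch_id] = entry

        def pop(self, batch_id: int) -> Optional[Batch]:
            return self.cache.pop(batch_id, None)

        def delete(self, batch_id: int):
            del self.cache[batch_id]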
server/bloom_inference/cli.py → server/text_generation/cli.py

@@ -3,7 +3,7 @@ import typer
 from pathlib import Path

-from bloom_inference import server, utils
+from text_generation import server, utils

 app = typer.Typer()
...
@@ -13,7 +13,7 @@ def serve(
     model_name: str,
     sharded: bool = False,
     quantize: bool = False,
-    uds_path: Path = "/tmp/bloom-inference",
+    uds_path: Path = "/tmp/text-generation",
 ):
     if sharded:
         assert (
...
@@ -35,8 +35,9 @@ def serve(
 @app.command()
 def download_weights(
     model_name: str,
+    extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name)
+    utils.download_weights(model_name, extension)

 if __name__ == "__main__":
...
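The download command now forwards the weight file extension to the helper, defaulting to safetensors. Calling the underlying function directly is equivalent to the CLI invocation used in the Makefile:

    from text_generation import utils

    # Same effect as: text-generation-server download-weights bigscience/bloom-560m
    utils.download_weights("bigscience/bloom-560m", ".safetensors")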
server/text_generation/models/__init__.py (new file, mode 100644)

from text_generation.models.model import Model
from text_generation.models.bloom import BLOOMSharded

__all__ = ["Model", "BLOOMSharded"]


def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
    if model_name.startswith("bigscience/bloom"):
        if sharded:
            return BLOOMSharded(model_name, quantize)
        else:
            if quantize:
                raise ValueError("quantization is not supported for non-sharded BLOOM")
            return Model(model_name)
    else:
        if sharded:
            raise ValueError("sharded is only supported for BLOOM")
        if quantize:
            raise ValueError("Quantization is only supported for BLOOM models")
        return Model(model_name)
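get_model is the dispatch point behind the commit title: BLOOM checkpoints go to the dedicated (optionally sharded, optionally quantized) implementation, and everything else falls back to the generic Model wrapper that loads through AutoModelForCausalLM. A short usage sketch (the second model name is only an example):

    from text_generation.models import get_model

    # BLOOM gets the dedicated sharded implementation.
    bloom = get_model("bigscience/bloom-560m", sharded=True, quantize=False)

    # Any other causal LM falls back to the best-effort generic Model.
    other = get_model("gpt2", sharded=False, quantize=False)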
server/bloom_inference/model.py → server/text_generation/models/bloom.py

 import torch
 import torch.distributed

-from dataclasses import dataclass
-from typing import List, Tuple, Optional, Dict
+from typing import List, Optional

 from accelerate import init_empty_weights
 from safetensors import safe_open
...
@@ -13,10 +12,8 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )

-from bloom_inference.pb import generate_pb2
-from bloom_inference.utils import (
-    StoppingCriteria,
-    NextTokenChooser,
+from text_generation.models import Model
+from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
     download_weights,
...
@@ -32,359 +29,9 @@ except Exception as e:
 torch.manual_seed(0)

-@dataclass
-class Batch:
-    batch_id: int
-    requests: List[generate_pb2.Request]
-    all_input_lengths: List[int]
-    input_ids: Dict[str, torch.Tensor]
-    all_input_ids: List[torch.Tensor]
-    next_token_choosers: List[NextTokenChooser]
-    stopping_criterias: List[StoppingCriteria]
-    size: int
-    max_sequence_length: int
-
-    def to_pb(self):
-        return generate_pb2.Batch(
-            id=self.batch_id,
-            requests=self.requests,
-            size=self.size,
-            max_sequence_length=self.max_sequence_length,
-        )
-
-    @classmethod
-    def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
-    ) -> "Batch":
...
-    @classmethod
-    def concatenate(cls, batches: List["Batch"]) -> "Batch":
...
-@dataclass
-class GeneratedText:
-    request: generate_pb2.Request
-    output: str
-
-    def to_pb(self) -> generate_pb2.GeneratedText:
-        return generate_pb2.GeneratedText(request=self.request, output=self.output)
-
-class BLOOM:
-    def __init__(self, model_name: str):
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-            dtype = torch.bfloat16
-        else:
-            self.device = torch.device("cpu")
-            dtype = torch.float32
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        self.model = (
-            AutoModelForCausalLM.from_pretrained(model_name)
-            .eval()
-            .to(self.device)
-            .to(dtype)
-        )
-        self.num_heads = self.model.base_model.num_heads
-
-    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
-        # Model Forward
-        return self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-    def generate_token(self, batch: Batch) -> Tuple[List[GeneratedText], Optional[Batch]]:
...
-class BLOOMSharded(BLOOM):
+class BLOOMSharded(Model):
     def __init__(self, model_name: str, quantize: bool = False):
-        super(BLOOM, self).__init__()
+        super(Model, self).__init__()
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
...
@@ -411,10 +58,10 @@ class BLOOMSharded(BLOOM):
         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name)
+            download_weights(model_name, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name)
+        filenames = weight_files(model_name, extension=".safetensors")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
...
@@ -500,7 +147,9 @@ class BLOOMSharded(BLOOM):
         if quantize:
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
-                    "bitsandbytes is not available on your machine"
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
                 )

         if (
...
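The code removed above (Batch.from_pb/concatenate and BLOOM.generate_token) is what drives the server's continuous batching: generate_token advances a batch by one token and returns the finished generations plus the batch to keep decoding, while concatenate merges live batches before the next step. A rough, hedged sketch of how those pieces fit together now that they live in the shared Model base class and the models/types module (the prefill/decode function names and call sites are assumptions; the real logic sits in the gRPC server, which is not shown on this page):

    from typing import List

    from text_generation.cache import Cache
    from text_generation.models import get_model
    from text_generation.models.types import Batch

    cache = Cache()
    model = get_model("bigscience/bloom-560m", sharded=False, quantize=False)


    def prefill(pb_batch) -> List:
        # First decoding step for a freshly received protobuf batch.
        batch = Batch.from_pb(pb_batch, model.tokenizer, model.device)
        generated_texts, next_batch = model.generate_token(batch)
        cache.set(next_batch)  # None when every request already finished
        return generated_texts


    def decode(batch_ids: List[int]) -> List:
        # Advance all cached batches by one token, merging them first.
        batches = [cache.pop(batch_id) for batch_id in batch_ids]
        batch = batches[0] if len(batches) == 1 else Batch.concatenate(batches)
        generated_texts, next_batch = model.generate_token(batch)
        cache.set(next_batch)
        return generated_texts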