"vscode:/vscode.git/clone" did not exist on "37b4e136e1824e86644540571af29d87fe528dc7"
Commit f16f2f5a authored by Olivier Dehaene's avatar Olivier Dehaene Committed by OlivierDehaene
Browse files

v0.1.0

parent 92c1ecd0
aml
router/target
\ No newline at end of file
target
\ No newline at end of file
.idea
\ No newline at end of file
.idea
target
\ No newline at end of file
......@@ -55,9 +55,9 @@ dependencies = [
[[package]]
name = "async-trait"
version = "0.1.57"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f"
checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c"
dependencies = [
"proc-macro2",
"quote",
......@@ -166,9 +166,9 @@ dependencies = [
[[package]]
name = "bumpalo"
version = "3.11.0"
version = "3.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d"
checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
[[package]]
name = "byteorder"
......@@ -255,9 +255,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.0.15"
version = "4.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
dependencies = [
"atty",
"bitflags",
......@@ -391,6 +391,16 @@ dependencies = [
"typenum",
]
[[package]]
name = "ctrlc"
version = "3.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d91974fbbe88ec1df0c24a4f00f99583667a7e2e6272b2b92d294d81e462173"
dependencies = [
"nix",
"winapi",
]
[[package]]
name = "darling"
version = "0.10.2"
......@@ -529,7 +539,7 @@ dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"windows-sys",
"windows-sys 0.36.1",
]
[[package]]
......@@ -936,9 +946,9 @@ dependencies = [
[[package]]
name = "itoa"
version = "1.0.3"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754"
checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
[[package]]
name = "js-sys"
......@@ -957,9 +967,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.134"
version = "0.2.135"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
[[package]]
name = "lock_api"
......@@ -1047,7 +1057,7 @@ dependencies = [
"libc",
"log",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys",
"windows-sys 0.36.1",
]
[[package]]
......@@ -1074,6 +1084,18 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nix"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb"
dependencies = [
"autocfg",
"bitflags",
"cfg-if",
"libc",
]
[[package]]
name = "nom"
version = "7.1.1"
......@@ -1084,6 +1106,16 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num_cpus"
version = "1.13.1"
......@@ -1185,6 +1217,12 @@ version = "6.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.1"
......@@ -1197,15 +1235,15 @@ dependencies = [
[[package]]
name = "parking_lot_core"
version = "0.9.3"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-sys",
"windows-sys 0.42.0",
]
[[package]]
......@@ -1300,9 +1338,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.46"
version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
dependencies = [
"unicode-ident",
]
......@@ -1530,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2"
dependencies = [
"lazy_static",
"windows-sys",
"windows-sys 0.36.1",
]
[[package]]
......@@ -1584,9 +1622,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.85"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44"
checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
dependencies = [
"itoa",
"ryu",
......@@ -1625,6 +1663,15 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
dependencies = [
"libc",
]
[[package]]
name = "slab"
version = "0.4.7"
......@@ -1680,11 +1727,21 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subprocess"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "syn"
version = "1.0.101"
version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
dependencies = [
"proc-macro2",
"quote",
......@@ -1741,13 +1798,24 @@ dependencies = [
"winapi",
]
[[package]]
name = "text-generation-launcher"
version = "0.1.0"
dependencies = [
"clap 4.0.17",
"ctrlc",
"subprocess",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"clap 4.0.15",
"clap 4.0.17",
"futures",
"parking_lot",
"serde",
......@@ -1872,6 +1940,7 @@ dependencies = [
"num_cpus",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"winapi",
......@@ -1910,9 +1979,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
version = "0.1.10"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [
"futures-core",
"pin-project-lite",
......@@ -2031,9 +2100,9 @@ dependencies = [
[[package]]
name = "tower-layer"
version = "0.3.1"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
......@@ -2043,9 +2112,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.36"
version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"log",
......@@ -2056,9 +2125,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.22"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2"
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
......@@ -2067,9 +2136,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.29"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7"
checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a"
dependencies = [
"once_cell",
"valuable",
......@@ -2108,11 +2177,11 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
version = "0.3.15"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
"ansi_term",
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
......@@ -2140,9 +2209,9 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992"
[[package]]
name = "unicode-ident"
version = "1.0.4"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd"
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
[[package]]
name = "unicode-normalization"
......@@ -2361,43 +2430,100 @@ version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
dependencies = [
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_msvc",
"windows_aarch64_msvc 0.36.1",
"windows_i686_gnu 0.36.1",
"windows_i686_msvc 0.36.1",
"windows_x86_64_gnu 0.36.1",
"windows_x86_64_msvc 0.36.1",
]
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc 0.42.0",
"windows_i686_gnu 0.42.0",
"windows_i686_msvc 0.42.0",
"windows_x86_64_gnu 0.42.0",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc 0.42.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
[[package]]
name = "windows_aarch64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
[[package]]
name = "windows_i686_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
[[package]]
name = "windows_i686_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
[[package]]
name = "windows_i686_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
[[package]]
name = "windows_i686_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
[[package]]
name = "windows_x86_64_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
[[package]]
name = "windows_x86_64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
[[package]]
name = "winreg"
version = "0.10.1"
......
[workspace]
members = [
"router",
"router/client",
"launcher"
]
[profile.release]
debug = 1
incremental = true
lto = "off"
FROM rust:1.64 as builder
FROM rust:1.64 as router-builder
WORKDIR /usr/src
......@@ -9,7 +9,17 @@ WORKDIR /usr/src/router
RUN cargo install --path .
FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
FROM rust:1.64 as launcher-builder
WORKDIR /usr/src
COPY launcher launcher
WORKDIR /usr/src/launcher
RUN cargo install --path .
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
......@@ -34,17 +44,15 @@ RUN cd ~ && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y
WORKDIR /usr/src
COPY server/Makefile server/Makefile
# Install specific version of torch
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
RUN cd server && make install-torch
# Install specific version of transformers
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
/opt/miniconda/envs/text-generation/bin/python setup.py install
WORKDIR /usr/src
RUN cd server && make install-transformers
# Install server
COPY server server
......@@ -52,9 +60,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router
COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY run.sh .
RUN chmod +x run.sh
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
CMD ["./run.sh"]
\ No newline at end of file
CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH
\ No newline at end of file
install-server:
cd server && make pip-install
install-router:
cd router && cargo install --path .
install-launcher:
cd launcher && cargo install --path .
install:
make install-server
make install-router
make install-launcher
run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2
run-bloom:
text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8
# Text Generation Inference
# LLM Text Generation Inference
A Rust and gRPC server for text generation inference.
<div align="center">
## Load Tests
![architecture](assets/architecture.jpg)
</div>
A Rust and gRPC server for large language models text generation inference.
## Load Tests for BLOOM
See `k6/load_test.js`
We send the default examples with a 1 second delay between each request.
We send the default examples with a 1 second delay between requests.
Stages:
- Ramp up to 50 concurrent requests per second in 1min
- Ramp up from 50 to 100 concurrent requests per second in 2min
- Ramp down to 0 concurrent requests per second in 1min
- Ramp up to 50 vus in 1min
- Ramp up from 50 to 100 vus in 2min
- Ramp down to 0 vus in 1min
| | avg | min | med | max | p(90) | p(95) | RPS |
|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------|
| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
| | avg | min | med | max | p(90) | p(95) | RPS |
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
## Install
```shell
cd server
pip install .
make install
```
```
cd router
cargo build --release
```
## Run
## Run
```shell
python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
make run-bloom-560m
```
## Test
```shell
./router/target/release/router
curl 127.0.0.1:3000/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
## TODO:
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
- [ ] Add tests
- [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server
- [ ] Improve past key layer indexing?
\ No newline at end of file
- [ ] Add tests for the `server/model` logic
\ No newline at end of file
......@@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom
NUM_GPUS: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2
inference_config:
liveness_route:
port: 3000
......
[package]
name = "text-generation-launcher"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = "3.2.3"
subprocess = "0.2.9"
tracing = "0.1.37"
tracing-subscriber = "0.3.16"
use clap::Parser;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::process::ExitCode;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::Arc;
use std::sync::{mpsc, Mutex};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_name: String,
#[clap(long, env)]
num_shard: Option<usize>,
#[clap(long, env)]
shard_directory: Option<String>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String,
#[clap(default_value = "localhost", long, env)]
master_addr: String,
#[clap(default_value = "29500", long, env)]
master_port: usize,
}
fn main() -> ExitCode {
tracing_subscriber::fmt::init();
// Pattern match configuration
let Args {
model_name,
num_shard,
shard_directory,
max_concurrent_requests,
max_input_length,
max_batch_size,
max_waiting_time,
port,
shard_uds_path,
master_addr,
master_port,
} = Args::parse();
// By default we only have one master shard
let num_shard = num_shard.unwrap_or(1);
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
// Start shard processes
for rank in 0..num_shard {
let model_name = model_name.clone();
let uds_path = shard_uds_path.clone();
let shard_directory = shard_directory.clone();
let master_addr = master_addr.clone();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
model_name,
uds_path,
shard_directory,
rank,
num_shard,
master_addr,
master_port,
status_sender,
shutdown,
shutdown_sender,
)
});
}
drop(shutdown_sender);
// Wait for shard to start
let mut shard_ready = 0;
while running.load(Ordering::SeqCst) {
match status_receiver.try_recv() {
Ok(ShardStatus::Ready) => {
shard_ready += 1;
if shard_ready == num_shard {
break;
}
}
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
Err(TryRecvError::Disconnected) => {
tracing::error!("Shard status channel disconnected");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
}
}
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::SUCCESS;
}
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
let mut webserver = match Popen::create(
&[
"text-generation-router",
"--max-concurrent-requests",
&max_concurrent_requests.to_string(),
"--max-input-length",
&max_input_length.to_string(),
"--max-batch-size",
&max_batch_size.to_string(),
"--max-waiting-time",
&max_waiting_time.to_string(),
"--port",
&port.to_string(),
"--master-shard-uds-path",
&format!("{}-0", shard_uds_path),
"--tokenizer-name",
&model_name,
],
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`")
}
}
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
};
// Redirect STDOUT and STDERR to the console
let webserver_stdout = webserver.stdout.take().unwrap();
let webserver_stderr = webserver.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(webserver_stdout);
let stderr = BufReader::new(webserver_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
// Default exit code
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {} failed:\n{}", rank, err);
exit_code = ExitCode::FAILURE;
break;
};
match webserver.poll() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
None => {
sleep(Duration::from_millis(100));
}
};
}
// Graceful termination
webserver.terminate().unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
}
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
uds_path: String,
shard_directory: Option<String>,
rank: usize,
world_size: usize,
master_addr: String,
master_port: usize,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{}-{}", uds_path, rank);
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
// Process args
let mut shard_argv = vec![
"bloom-inference-server".to_string(),
"serve".to_string(),
model_name,
"--uds-path".to_string(),
uds_path,
];
if world_size > 1 {
shard_argv.push("--sharded".to_string());
}
if let Some(shard_directory) = shard_directory {
shard_argv.push("--shard-directory".to_string());
shard_argv.push(shard_directory);
}
// Start process
tracing::info!("Starting shard {}", rank);
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
// NCCL env vars
env: Some(vec![
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
(
"WORLD_SIZE".parse().unwrap(),
world_size.to_string().parse().unwrap(),
),
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
]),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("bloom-inference-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.unwrap();
return;
}
};
let mut ready = false;
let start_time = Instant::now();
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return;
}
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank);
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready {
tracing::info!("Waiting for shard {} to be ready...", rank);
}
sleep(Duration::from_secs(5));
}
}
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
tracing::info!("Shutting down shards");
// Update shutdown value to true
// This will be picked up by the shard manager
{
let mut shutdown = shutdown.lock().unwrap();
*shutdown = true;
}
// Wait for shards to shutdown
// This will block till all shutdown_sender are dropped
let _ = shutdown_receiver.recv();
}
......@@ -11,10 +11,6 @@ service TextGenerationService {
rpc Generate (GenerateRequest) returns (GenerateResponse);
/// Generate tokens for a list of cached batches
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
/// Generate tokens until the text of at least one request of the batch is generated
rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
/// Generate tokens until the text of at least one request of the cached batches i finished
rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
}
/// Empty request
......@@ -92,27 +88,3 @@ message GenerateWithCacheResponse {
/// Next batch (cached)
optional Batch batch = 2;
}
message GenerateUntilFinishedRequest {
/// Batch
Batch batch = 1;
}
message GenerateUntilFinishedResponse {
/// Finished requests
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
}
message GenerateUntilFinishedWithCacheRequest {
/// Cached batches
repeated Batch batches = 1;
}
message GenerateUntilFinishedWithCacheResponse {
/// Finished requests
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
}
......@@ -22,16 +22,7 @@ serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tracing = "0.1.36"
tracing-subscriber = "0.3.15"
[workspace]
members = [
"client",
]
[profile.release]
debug = 1
incremental = true
lto = "off"
......@@ -5,8 +5,6 @@ edition = "2021"
[dependencies]
futures = "^0.3"
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-metadata = { path = "../../grpc-metadata" }
prost = "^0.9"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }
......
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use tonic::transport::{Channel, Uri};
use tracing::*;
/// BLOOM Inference gRPC client
/// Text Generation Inference gRPC client
#[derive(Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
......@@ -34,6 +35,7 @@ impl Client {
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
......@@ -46,6 +48,7 @@ impl Client {
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
......@@ -54,6 +57,7 @@ impl Client {
Ok(urls)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
......@@ -64,6 +68,10 @@ impl Client {
Ok(())
}
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
......@@ -76,6 +84,10 @@ impl Client {
Ok((response.generated_texts, response.batch))
}
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))]
pub async fn generate_with_cache(
&mut self,
......@@ -90,34 +102,4 @@ impl Client {
.into_inner();
Ok((response.generated_texts, response.batch))
}
#[instrument(skip(self))]
pub async fn generate_until_finished(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
let response = self
.stub
.generate_until_finished(request)
.instrument(info_span!("generate_until_finished"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
}
#[instrument(skip(self))]
pub async fn generate_until_finished_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
let response = self
.stub
.generate_until_finished_with_cache(request)
.instrument(info_span!("generate_until_finished_with_cache"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
}
}
//! BLOOM Inference gRPC client library
//! Text Generation gRPC client library
mod client;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod sharded_client;
......@@ -8,7 +9,7 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient;
use thiserror::Error;
pub use tonic::transport;
use tonic::transport;
use tonic::Status;
#[derive(Error, Debug, Clone)]
......@@ -21,7 +22,7 @@ pub enum ClientError {
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
Self::Generation(err.to_string())
Self::Generation(err.message().to_string())
}
}
......
/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use futures::future::join_all;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;
/// List of all available commands that can be sent through the command channel
#[derive(Clone, Debug)]
enum Command {
Generate(
......@@ -14,36 +16,32 @@ enum Command {
Vec<Batch>,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
GenerateUntilFinished(
Batch,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
GenerateUntilFinishedWithCache(
Vec<Batch>,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
ClearCache(mpsc::Sender<Result<()>>),
}
/// Tokio task that handles the communication with a single shard
///
/// We subscribe on a broadcast channel to receive commands that will be sent by
/// the ShardedClient.
///
/// Each command is fan out to all shards.
///
/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
/// producer = the shards, single consumer = the ShardedClient).
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
while let Ok(message) = request_subscriber.recv().await {
match message {
Command::Generate(batch, response_tx) => {
let result = client.generate(batch).await;
// We can unwrap_or(()) here because the only error that can happen is if the
// receiver is dropped, which means that the ShardedClient already received a
// response from another shard
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateWithCache(batches, response_tx) => {
let result = client.generate_with_cache(batches).await;
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateUntilFinished(batch, response_tx) => {
let result = client.generate_until_finished(batch).await;
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
let result = client.generate_until_finished_with_cache(batches).await;
response_tx.try_send(result).unwrap_or(());
}
Command::ClearCache(response_tx) => {
let result = client.clear_cache().await;
response_tx.try_send(result).unwrap_or(());
......@@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
}
}
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
_clients: Vec<Client>,
request_tx: broadcast::Sender<Command>,
}
impl ShardedClient {
fn new(mut clients: Vec<Client>) -> Self {
fn new(clients: Vec<Client>) -> Self {
// The broadcast channel to communicate with the shards
// We use a capacity of one as the shards are not asynchronous and can only process one
// command at a time
let (request_tx, _) = broadcast::channel(1);
for client in clients.drain(..) {
// Spawn client tasks
for client in clients.iter() {
let request_subscriber = request_tx.subscribe();
tokio::spawn(client_task(client, request_subscriber));
tokio::spawn(client_task(client.clone(), request_subscriber));
}
Self { request_tx }
Self {
_clients: clients,
request_tx,
}
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await.unwrap();
let futures = uris.into_iter().map(|path| Client::connect_uds(path));
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given url
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
......@@ -87,51 +97,43 @@ impl ShardedClient {
Self::from_master_client(master_client).await
}
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::Generate(batch, response_tx))
.unwrap();
// As soon as we receive one response, we can return as all shards will return the same
response_rx.recv().await.unwrap()
}
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate_with_cache(
&self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateWithCache(batches, response_tx))
.unwrap();
// As soon as we receive one response, we can return as all shards will return the same
response_rx.recv().await.unwrap()
}
pub async fn generate_until_finished(
&self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateUntilFinished(batch, response_tx))
.unwrap();
response_rx.recv().await.unwrap()
}
pub async fn generate_until_finished_with_cache(
&self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateUntilFinishedWithCache(
batches,
response_tx,
))
.unwrap();
response_rx.recv().await.unwrap()
}
/// Clear the past generations cache
pub async fn clear_cache(&self) -> Result<()> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::ClearCache(response_tx))
......
use crate::server::GenerateRequest;
/// Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;
const MAX_LENGTH: usize = 128;
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded,
}
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
}
}
}
/// Batcher
#[derive(Clone)]
pub struct Batcher {
/// Request database
db: Db,
/// Shared state
shared: Arc<Shared>,
}
/// Batcher shared state
struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify,
}
impl Batcher {
pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
max_waiting_time: Duration,
) -> Self {
// Batcher shared state
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
max_batch_size,
max_waiting_time,
client,
db.clone(),
shared.clone(),
));
Self { db, shared }
}
/// Add a new request to the database and return a future that will generate the text
pub(crate) async fn infer(
&self,
input_length: usize,
request: GenerateRequest,
) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH {
return Err(InferError::Overloaded);
}
let (request_tx, request_rx) = oneshot::channel();
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();
// Try to append the request to the database
self.db.append(Entry {
request,
response_tx: request_tx,
response_tx,
input_length,
time: Instant::now(),
});
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender
match response_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(err) => Err(InferError::GenerationError(err.to_string())),
}
}
}
async fn batching_task(max_batch_size: usize,
client: ShardedClient,
db: Db,
shared: Arc<Shared>) {
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
max_batch_size: usize,
max_waiting_time: Duration,
client: ShardedClient,
db: Db,
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop {
// Wait for a notification from the Batcher struct
shared.batching_task.notified().await;
if let Some(batch) = db.next_batch(max_batch_size) {
let request_ids = batch.requests.iter().map(|req| req.id).collect();
let mut cached_batch = match batch.size {
size if size > limit_min_batch_size => {
wrap_future(client.generate_until_finished(batch), request_ids, &db).await
}
_ => wrap_future(client.generate(batch), request_ids, &db).await,
};
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
let mut current_batch_size = batch.size;
// Get current batch info
let batch_size = batch.size;
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
let mut batches = vec![batch];
if current_batch_size <= limit_min_batch_size {
if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
let new_batch_request_ids =
new_batch.requests.iter().map(|req| req.id).collect();
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
// Get the next batch from the DB that meet our minimum size criteria
if let Some((new_request_ids, new_batch)) =
db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
.await;
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch);
}
}
// If we don't have enough requests to meet the minimum size criteria, we
// try to get the next batch from the DB that have been waiting over
// the max_waiting_time
else if let Some((new_request_ids, new_batch)) =
db.next_batch(None, max_batch_size, Some(max_waiting_time))
{
let new_cached_batch =
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
current_batch_size += new_cached_batch.size;
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch);
}
}
}
cached_batch = match current_batch_size {
size if size > limit_min_batch_size => {
wrap_future(
client.generate_until_finished_with_cache(batches),
request_ids,
&db,
)
.await
}
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
};
cached_batch =
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
}
}
}
}
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
request_ids: Vec<u64>,
......@@ -134,6 +163,7 @@ async fn wrap_future(
send_generated(generated_texts, db);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, request_ids, db);
None
......@@ -141,16 +171,20 @@ async fn wrap_future(
}
}
/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| {
// We can `expect` here as the request id should always be in the DB
let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(());
});
}
/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| {
// We can `expect` here as the request id should always be in the DB
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
......@@ -158,3 +192,18 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
entry.response_tx.send(Ok(output.output)).unwrap_or(());
});
}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
}
/// Convert to Axum supported format
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
}
}
}
/// This code is massively inspired by Tokio mini-redis
use crate::server::{GenerateParameters, GenerateRequest};
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
/// Database entry
#[derive(Debug)]
pub(crate) struct Entry {
/// Request
pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task
pub response_tx: Sender<Result<String, ClientError>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created
pub time: Instant,
}
impl From<GenerateParameters> for LogitsWarperParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
}
}
}
/// Request Database
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
}
/// Shared state
#[derive(Debug)]
pub struct Shared {
state: RwLock<State>,
state: Mutex<State>,
}
/// Database State
#[derive(Debug)]
struct State {
/// Database entries organized in a BTreeMap to be able to iterate over them in order
entries: BTreeMap<u64, Entry>,
/// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why.
/// Id of the next entry
next_id: u64,
/// Id of the next batch
next_batch_id: u64,
/// Current batch id
/// Start ID of the next batch. Used to iterate inside the entries BTreeMap
next_batch_start_id: u64,
}
impl State {
/// Get the next requests
fn next_requests(
&self,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Vec<Request>)> {
// Iterates for max_size over the BTreemap starting from next_batch_start_id
let mut requests = Vec::new();
let mut ids = Vec::new();
for (id, entry) in self
.entries
// Start from next_batch_start_id
.range(self.next_batch_start_id..)
// Take max_size
.take(max_size)
{
if let Some(min_waiting_time) = min_waiting_time {
// Only take entries that waited for at least min_waiting_time
if entry.time.elapsed() < min_waiting_time {
// Since entries are ordered, we already know that all following entries won't
// satisfy the condition
break;
}
}
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
});
ids.push(*id);
}
if requests.is_empty() {
None
} else {
Some((ids, requests))
}
}
}
impl Db {
pub(crate) fn new() -> Self {
// Shared state
let shared = Arc::new(Shared {
state: RwLock::new(State {
state: Mutex::new(State {
entries: BTreeMap::new(),
next_id: 0,
next_batch_id: 0,
......@@ -62,55 +112,46 @@ impl Db {
Self { shared }
}
/// Append an entry to the database
pub(crate) fn append(&self, entry: Entry) {
let mut state = self.shared.state.write();
// Acquire lock
let mut state = self.shared.state.lock();
// Insert entry
let id = state.next_id;
state.next_id += 1;
state.entries.insert(id, entry);
}
/// Remove an entry from the database if it exists
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
let mut state = self.shared.state.write();
let mut state = self.shared.state.lock();
state.entries.remove(id)
}
pub(crate) fn len(&self) -> usize {
let state = self.shared.state.read();
state.entries.len()
}
fn next_requests(&self, max_size: usize) -> Option<(u64, Vec<Request>)> {
let state = self.shared.state.read();
let requests: Vec<Request> = state
.entries
.range(state.next_batch_start_id..)
.take(max_size)
.map(|(id, entry)| Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
})
.collect();
if requests.is_empty() {
None
} else {
let last_id = requests.last().unwrap().id;
Some((last_id, requests))
}
}
// Get the next batch
pub(crate) fn next_batch(
&self,
min_size: Option<usize>,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Batch)> {
// Acquire lock
let mut state = self.shared.state.lock();
// Get requests from the database
if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
if let Some(min_size) = min_size {
// If min_size is set, only return a batch if there are enough requests
if requests.len() < min_size {
return None;
}
}
pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
let mut state = self.shared.state.write();
// Batch size
let size = requests.len();
// Longest input length for all requests in batch size
// Used for padding inside the inference server
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch {
id: state.next_batch_id,
......@@ -118,34 +159,24 @@ impl Db {
size: size as u32,
max_sequence_length,
};
state.next_batch_start_id = last_id + 1;
// Update next_batch_start_id to the last id in the batch + 1
state.next_batch_start_id = ids.last().unwrap() + 1;
// Increment batch id
state.next_batch_id += 1;
return Some(batch);
return Some((ids, batch));
}
None
}
}
pub(crate) fn next_batch_minimum_size(
&self,
min_size: usize,
max_size: usize,
) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
if requests.len() >= min_size {
let mut state = self.shared.state.write();
let size = requests.len();
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch {
id: state.next_batch_id,
requests,
size: size as u32,
max_sequence_length,
};
state.next_batch_start_id = last_id + 1;
state.next_batch_id += 1;
return Some(batch);
}
impl From<GenerateParameters> for LogitsWarperParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
}
None
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment