Unverified commit 20c3c594 authored by OlivierDehaene, committed by GitHub

feat(router): refactor API and add openAPI schemas (#53)

parent b1482d90
......@@ -5,6 +5,8 @@ on:
push:
branches:
- 'main'
tags:
- 'v*'
pull_request:
branches:
- 'main'
......@@ -43,6 +45,8 @@ jobs:
ghcr.io/huggingface/text-generation-inference
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image
......
......@@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.17"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
dependencies = [
"async-trait",
"axum-core",
......@@ -101,8 +101,10 @@ dependencies = [
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
......@@ -114,9 +116,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.2.9"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [
"async-trait",
"bytes",
......@@ -124,6 +126,7 @@ dependencies = [
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
......@@ -207,7 +210,7 @@ dependencies = [
"tar",
"tempfile",
"thiserror",
"zip",
"zip 0.5.13",
"zip-extensions",
]
......@@ -465,6 +468,15 @@ dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.3.7"
......@@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
dependencies = [
"autocfg",
"hashbrown",
"serde",
]
[[package]]
......@@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]]
name = "matchit"
version = "0.5.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
......@@ -1024,6 +1037,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
......@@ -1552,12 +1575,62 @@ dependencies = [
"winreg",
]
[[package]]
name = "rust-embed"
version = "6.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "6.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"shellexpand",
"syn",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "7.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustversion"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
[[package]]
name = "ryu"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.20"
......@@ -1628,6 +1701,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
......@@ -1660,6 +1742,15 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shellexpand"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
dependencies = [
"dirs 4.0.0",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
......@@ -1797,7 +1888,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "0.1.0"
version = "0.2.0"
dependencies = [
"futures",
"prost",
......@@ -1812,7 +1903,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "0.1.0"
version = "0.2.0"
dependencies = [
"clap 4.0.22",
"ctrlc",
......@@ -1827,7 +1918,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "0.1.0"
version = "0.2.0"
dependencies = [
"async-stream",
"axum",
......@@ -1845,6 +1936,8 @@ dependencies = [
"tokio-stream",
"tracing",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
]
[[package]]
......@@ -1921,7 +2014,7 @@ dependencies = [
"cached-path",
"clap 2.34.0",
"derive_builder",
"dirs",
"dirs 3.0.2",
"esaxx-rs",
"getrandom",
"indicatif 0.15.0",
......@@ -2234,6 +2327,15 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.8"
......@@ -2293,6 +2395,46 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utoipa"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
dependencies = [
"indexmap",
"serde",
"serde_json",
"utoipa-gen",
]
[[package]]
name = "utoipa-gen"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "utoipa-swagger-ui"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
dependencies = [
"axum",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"utoipa",
"zip 0.6.4",
]
[[package]]
name = "valuable"
version = "0.1.0"
......@@ -2317,6 +2459,17 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
dependencies = [
"same-file",
"winapi",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.0"
......@@ -2589,11 +2742,23 @@ dependencies = [
"time",
]
[[package]]
name = "zip"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]]
name = "zip-extensions"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
dependencies = [
"zip",
"zip 0.5.13",
]
......@@ -4,9 +4,6 @@ members = [
"router/client",
"launcher"
]
exclude = [
"server/safetensors",
]
[profile.release]
debug = 1
......
......@@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
MODEL_BASE_PATH=/data \
HUGGINGFACE_HUB_CACHE=/data \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
NUM_GPUS=1 \
NUM_SHARD=1 \
SAFETENSORS_FAST_GPU=1 \
PORT=80 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \
......@@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
# Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output
\ No newline at end of file
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
\ No newline at end of file
......@@ -15,17 +15,23 @@ server-dev:
router-dev:
cd router && cargo run
integration-tests: install-router install-launcher
cargo test
python-tests:
cd server && pytest tests
run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
run-bloom-560m-quantize:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
download-bloom:
text-generation-server download-weights bigscience/bloom
run-bloom:
text-generation-launcher --model-name bigscience/bloom --num-shard 8
text-generation-launcher --model-id bigscience/bloom --num-shard 8
run-bloom-quantize:
text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
\ No newline at end of file
<div align="center">
# Text Generation Inference
<div align="center">
<a href="https://github.com/huggingface/text-generation-inference">
<img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social">
</a>
<a href="https://github.com/huggingface/text-generation-inference/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/huggingface/text-generation-inference">
</a>
<a href="https://huggingface.github.io/text-generation-inference">
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
</a>
![architecture](assets/architecture.jpg)
</div>
A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
to power Bloom, BloomZ and MT0-XXL api-inference widgets.
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
to power LLM api-inference widgets.
## Table of contents
- [Features](#features)
- [Officially Supported Models](#officially-supported-models)
- [Get Started](#get-started)
- [Docker](#docker)
- [Local Install](#local-install)
- [OpenAPI](#api-documentation)
- [CUDA Kernels](#cuda-kernels)
- [Run BLOOM](#run-bloom)
- [Download](#download)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
## Features
- Token streaming using Server-Sent Events (SSE)
- [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
......@@ -36,30 +63,63 @@ or
`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
## Load Tests for BLOOM
## Get started
### Docker
See `k6/load_test.js`
The easiest way to get started is with the official Docker container:
```shell
model=bigscience/bloom-560m
num_shard=2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
```
| | avg | min | med | max | p(90) | p(95) | RPS |
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
You can then query the model using either the `/generate` or `/generate_stream` routes:
## Install
```shell
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
```shell
make install
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
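For programmatic clients, the same streaming route can be consumed from Rust. Below is a minimal sketch (not part of this commit) that reads the raw SSE frames; it assumes the `reqwest` crate (with the `json` and `stream` features) plus `serde_json`, `futures-util` and `tokio`:

```rust
// Sketch: read the raw Server-Sent Events frames from /generate_stream.
use futures_util::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "inputs": "Testing API",
        "parameters": { "max_new_tokens": 9 }
    });
    let response = reqwest::Client::new()
        .post("http://127.0.0.1:8080/generate_stream")
        .json(&body)
        .send()
        .await?;
    // Each event arrives as a `data: {...}` frame followed by a blank line;
    // here we simply print the chunks as they come in.
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}
```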
## Run
To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
### API documentation
### BLOOM 560-m
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Local install
You can also opt to install `text-generation-inference` locally. You will need to have Cargo and Python installed on your
machine.
```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
make run-bloom-560m
```
### BLOOM
### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
Be aware that the official Docker image has them enabled by default.
## Run BLOOM
### Download
First you need to download the weights:
......@@ -67,29 +127,30 @@ First you need to download the weights:
make download-bloom
```
### Run
```shell
make run-bloom # Requires 8xA100 80GB
```
### Quantization
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell
make run-bloom-quantize # Requires 8xA100 40GB
```
## Test
## Develop
```shell
curl 127.0.0.1:3000/generate \
-v \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
make server-dev
make router-dev
```
## Develop
## Testing
```shell
make server-dev
make router-dev
make python-tests
make integration-tests
```
\ No newline at end of file
......@@ -4,9 +4,9 @@ endpoint_name: bloom-inference
model: azureml:bloom:1
model_mount_path: /var/azureml-model
environment_variables:
MODEL_BASE_PATH: /var/azureml-model/bloom
HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
MODEL_ID: bigscience/bloom
NUM_GPUS: 8
NUM_SHARD: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
inference_config:
......
assets/architecture.jpg (binary image replaced: 132 KB → 334 KB)
<html>
<head>
<!-- Load the latest Swagger UI code and style from npm using unpkg.com -->
<script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"/>
<title>Text Generation Inference API</title>
</head>
<body>
<div id="swagger-ui"></div> <!-- Div to hold the UI component -->
<script>
window.onload = function () {
// Begin Swagger UI call region
const ui = SwaggerUIBundle({
url: "openapi.json", //Location of Open API spec in the repo
dom_id: '#swagger-ui',
deepLinking: true,
supportedSubmitMethods: [],
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
})
window.ui = ui
}
</script>
</body>
</html>
\ No newline at end of file
{
"openapi": "3.0.3",
"info": {
"title": "Text Generation Inference",
"description": "Text Generation Webserver",
"contact": {
"name": "Olivier Dehaene",
"email": ""
},
"license": {
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.2.0"
},
"paths": {
"/generate": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/GenerateResponse"
}
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Incomplete generation"
}
}
}
}
},
"deprecated": false
}
},
"/generate_stream": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate a stream of token using Server Side Events",
"description": "Generate a stream of token using Server Side Events",
"operationId": "generate_stream",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/StreamResponse"
}
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Incomplete generation"
}
}
}
}
},
"deprecated": false
}
}
},
"components": {
"schemas": {
"Details": {
"type": "object",
"required": [
"finish_reason",
"generated_tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42
},
"tokens": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
}
}
},
"ErrorResponse": {
"type": "object",
"required": [
"error"
],
"properties": {
"error": {
"type": "string"
}
}
},
"FinishReason": {
"type": "string",
"enum": [
"length",
"eos_token",
"stop_sequence"
]
},
"GenerateParameters": {
"type": "object",
"properties": {
"details": {
"type": "boolean",
"default": "true"
},
"do_sample": {
"type": "boolean",
"default": "false",
"example": true
},
"max_new_tokens": {
"type": "integer",
"format": "int32",
"default": "20",
"exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0
},
"repetition_penalty": {
"type": "number",
"format": "float",
"default": "null",
"example": 1.03,
"nullable": true,
"exclusiveMinimum": 0.0
},
"seed": {
"type": "integer",
"format": "int64"
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"example": [
"photographer"
],
"maxItems": 4
},
"temperature": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.5,
"nullable": true,
"exclusiveMinimum": 0.0
},
"top_k": {
"type": "integer",
"format": "int32",
"default": "null",
"example": 10,
"nullable": true,
"exclusiveMinimum": 0.0
},
"top_p": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.95,
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
}
}
},
"GenerateRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"inputs": {
"type": "string",
"example": "My name is Olivier and I"
},
"parameters": {
"$ref": "#/components/schemas/GenerateParameters"
}
}
},
"GenerateResponse": {
"type": "object",
"required": [
"generated_text"
],
"properties": {
"details": {
"$ref": "#/components/schemas/Details"
},
"generated_text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": {
"type": "object",
"required": [
"finish_reason",
"generated_tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42
}
}
},
"StreamResponse": {
"type": "object",
"required": [
"token"
],
"properties": {
"details": {
"$ref": "#/components/schemas/StreamDetails"
},
"generated_text": {
"type": "string",
"default": "null",
"example": "test",
"nullable": true
},
"token": {
"$ref": "#/components/schemas/Token"
}
}
},
"Token": {
"type": "object",
"required": [
"id",
"text",
"logprob"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0
},
"logprob": {
"type": "number",
"format": "float",
"example": -0.34,
"nullable": true
},
"text": {
"type": "string",
"example": "test"
}
}
}
}
},
"tags": [
{
"name": "Text Generation Inference",
"description": "Hugging Face Text Generation Inference API"
}
]
}
\ No newline at end of file
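For orientation, the `GenerateRequest` and `GenerateParameters` schemas above correspond to serde structures roughly like the following. This is a sketch inferred from the JSON schema, not the router's actual source; field names and defaults mirror the spec:

```rust
// Sketch of the /generate request payload, derived from the OpenAPI schema.
use serde::Deserialize;

#[derive(Deserialize)]
struct GenerateRequest {
    inputs: String,
    #[serde(default)]
    parameters: GenerateParameters,
}

#[derive(Deserialize, Default)]
struct GenerateParameters {
    temperature: Option<f32>,        // nullable, must be > 0
    repetition_penalty: Option<f32>, // nullable, must be > 0
    top_k: Option<i32>,              // nullable, must be > 0
    top_p: Option<f32>,              // nullable, in (0, 1]
    do_sample: Option<bool>,         // schema default: false
    max_new_tokens: Option<u32>,     // schema default: 20, in (0, 512)
    stop: Option<Vec<String>>,       // at most 4 stop sequences
    details: Option<bool>,           // schema default: true
    seed: Option<u64>,
}
```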
[package]
name = "text-generation-launcher"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"
......
......@@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_name: String,
model_id: String,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
......@@ -49,7 +49,7 @@ struct Args {
fn main() -> ExitCode {
// Pattern match configuration
let Args {
model_name,
model_id,
revision,
num_shard,
quantize,
......@@ -92,7 +92,7 @@ fn main() -> ExitCode {
// Start shard processes
for rank in 0..num_shard {
let model_name = model_name.clone();
let model_id = model_id.clone();
let revision = revision.clone();
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
......@@ -101,7 +101,7 @@ fn main() -> ExitCode {
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
model_name,
model_id,
revision,
quantize,
uds_path,
......@@ -167,7 +167,7 @@ fn main() -> ExitCode {
"--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path),
"--tokenizer-name".to_string(),
model_name,
model_id,
];
if json_output {
......@@ -256,7 +256,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
model_id: String,
revision: Option<String>,
quantize: bool,
uds_path: String,
......@@ -278,7 +278,7 @@ fn shard_manager(
let mut shard_argv = vec![
"text-generation-server".to_string(),
"serve".to_string(),
model_name,
model_id,
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
......
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
[
10264,
"Test",
null
],
[
8821,
" request",
-11.895094
]
],
"tokens": [
[
17,
".",
-1.8267941
],
[
1587,
"get",
-2.4674964
],
[
11,
"(",
-1.9060438
],
[
5,
"\"",
-1.2279553
],
[
4899,
"action",
-4.170306
],
[
5,
"\"",
-0.3247902
],
[
12,
")",
-1.0773602
],
[
30,
";",
-0.27640444
],
[
837,
"\n ",
-1.6970599
],
[
1320,
" if",
-1.4495552
],
[
375,
" (",
-0.2360998
],
[
4899,
"action",
-1.1916926
],
[
3535,
" ==",
-0.8918663
],
[
5109,
" null",
-0.39334255
],
[
12,
")",
-0.4321134
],
[
731,
" {",
-0.17701954
],
[
1260,
"\n ",
-0.07027287
],
[
10519,
" throw",
-1.3915133
],
[
2084,
" new",
-0.042013377
],
[
150858,
" RuntimeException",
-1.7330077
]
]
},
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
]
\ No newline at end of file
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
{
"id": 10264,
"logprob": null,
"text": "Test"
},
{
"id": 8821,
"logprob": -11.894989,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 17,
"logprob": -1.8267672,
"text": "."
},
{
"id": 1587,
"logprob": -2.4674969,
"text": "get"
},
{
"id": 11,
"logprob": -1.906001,
"text": "("
},
{
"id": 5,
"logprob": -1.2279545,
"text": "\""
},
{
"id": 4899,
"logprob": -4.170299,
"text": "action"
},
{
"id": 5,
"logprob": -0.32478866,
"text": "\""
},
{
"id": 12,
"logprob": -1.0773665,
"text": ")"
},
{
"id": 30,
"logprob": -0.27640742,
"text": ";"
},
{
"id": 837,
"logprob": -1.6970354,
"text": "\n "
},
{
"id": 1320,
"logprob": -1.4495516,
"text": " if"
},
{
"id": 375,
"logprob": -0.23609057,
"text": " ("
},
{
"id": 4899,
"logprob": -1.1916996,
"text": "action"
},
{
"id": 3535,
"logprob": -0.8918753,
"text": " =="
},
{
"id": 5109,
"logprob": -0.3933342,
"text": " null"
},
{
"id": 12,
"logprob": -0.43212673,
"text": ")"
},
{
"id": 731,
"logprob": -0.17702064,
"text": " {"
},
{
"id": 1260,
"logprob": -0.07027565,
"text": "\n "
},
{
"id": 10519,
"logprob": -1.3915029,
"text": " throw"
},
{
"id": 2084,
"logprob": -0.04201372,
"text": " new"
},
{
"id": 150858,
"logprob": -1.7329919,
"text": " RuntimeException"
}
]
},
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
\ No newline at end of file
......@@ -9,11 +9,18 @@ use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};
#[derive(Deserialize)]
pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
}
#[derive(Deserialize)]
struct Details {
finish_reason: String,
generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>,
tokens: Vec<Token>,
}
#[derive(Deserialize)]
......@@ -22,11 +29,11 @@ struct GeneratedText {
details: Details,
}
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
"--model-name".to_string(),
model_name.clone(),
"--model-id".to_string(),
model_id.clone(),
"--num-shard".to_string(),
num_shard.to_string(),
"--port".to_string(),
......@@ -68,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
launcher.terminate().unwrap();
launcher.wait().unwrap();
panic!("failed to launch {}", model_name)
panic!("failed to launch {}", model_id)
}
fn test_model(
model_name: String,
model_id: String,
num_shard: usize,
port: usize,
master_port: usize,
) -> GeneratedText {
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
let mut launcher = start_launcher(model_id, num_shard, port, master_port);
let data = r#"
{
......@@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
let file = File::open(d).unwrap();
let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
results.pop().unwrap()
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
result
}
fn compare_results(result: GeneratedText, expected: GeneratedText) {
......@@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
.into_iter()
.zip(expected.details.tokens.into_iter())
{
assert_eq!(token.0, expected_token.0);
assert_eq!(token.1, expected_token.1);
if let Some(logprob) = token.2 {
let expected_logprob = expected_token.2.unwrap();
assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text);
if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else {
assert_eq!(token.2, expected_token.2);
assert_eq!(token.logprob, expected_token.logprob);
}
}
}
......
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
[
0,
"<pad>",
null
]
],
"tokens": [
[
259,
"",
-1.3656927
],
[
215100,
"\"\"\"",
-2.6551573
],
[
46138,
"Test",
-1.8059857
],
[
287,
"the",
-1.2102449
],
[
259,
"",
-1.6057279
],
[
49076,
"contents",
-3.6060903
],
[
304,
"of",
-0.5270343
],
[
287,
"the",
-0.62522805
],
[
259,
"",
-1.4069618
],
[
49076,
"contents",
-2.621994
],
[
304,
"of",
-1.3172221
],
[
287,
"the",
-0.3501925
],
[
259,
"",
-0.7219573
],
[
49076,
"contents",
-1.0494149
],
[
260,
".",
-1.0803378
],
[
259,
"",
-0.32933083
],
[
215100,
"\"\"\"",
-0.11268901
],
[
2978,
"test",
-1.5846587
],
[
290,
"_",
-0.49796978
],
[
4125,
"test",
-2.0026445
]
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
]
\ No newline at end of file
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 259,
"logprob": -1.3656927,
"text": ""
},
{
"id": 215100,
"logprob": -2.6551573,
"text": "\"\"\""
},
{
"id": 46138,
"logprob": -1.8059857,
"text": "Test"
},
{
"id": 287,
"logprob": -1.2102449,
"text": "the"
},
{
"id": 259,
"logprob": -1.6057279,
"text": ""
},
{
"id": 49076,
"logprob": -3.6060903,
"text": "contents"
},
{
"id": 304,
"logprob": -0.5270343,
"text": "of"
},
{
"id": 287,
"logprob": -0.62522805,
"text": "the"
},
{
"id": 259,
"logprob": -1.4069618,
"text": ""
},
{
"id": 49076,
"logprob": -2.621994,
"text": "contents"
},
{
"id": 304,
"logprob": -1.3172221,
"text": "of"
},
{
"id": 287,
"logprob": -0.3501925,
"text": "the"
},
{
"id": 259,
"logprob": -0.7219573,
"text": ""
},
{
"id": 49076,
"logprob": -1.0494149,
"text": "contents"
},
{
"id": 260,
"logprob": -1.0803378,
"text": "."
},
{
"id": 259,
"logprob": -0.32933083,
"text": ""
},
{
"id": 215100,
"logprob": -0.11268901,
"text": "\"\"\""
},
{
"id": 2978,
"logprob": -1.5846587,
"text": "test"
},
{
"id": 290,
"logprob": -0.49796978,
"text": "_"
},
{
"id": 4125,
"logprob": -2.0026445,
"text": "test"
}
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
\ No newline at end of file
......@@ -71,13 +71,19 @@ message Batch {
uint32 size = 3;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Finish reason
string finish_reason = 3;
FinishReason finish_reason = 3;
/// Seed
optional uint64 seed = 4;
}
......
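With this change the wire format carries an integer-backed enum instead of a free-form string. The sketch below mirrors the shape of the prost-generated Rust enum and maps it onto the string values exposed by the HTTP API (`length`, `eos_token`, `stop_sequence`); the conversion function is illustrative, not code from this commit:

```rust
// Sketch: prost generates an i32-backed enum for the proto enum above;
// this mirrors its shape and maps it to the REST API's FinishReason strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum FinishReason {
    Length = 0,
    EosToken = 1,
    StopSequence = 2,
}

impl FinishReason {
    /// String representation used by the OpenAPI `FinishReason` schema.
    pub fn as_api_str(self) -> &'static str {
        match self {
            FinishReason::Length => "length",
            FinishReason::EosToken => "eos_token",
            FinishReason::StopSequence => "stop_sequence",
        }
    }
}
```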
[package]
name = "text-generation-router"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Webserver"
......@@ -14,7 +14,7 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.3"
axum = { version = "0.5.16", features = ["json", "serde_json"] }
axum = { version = "0.6.4", features = ["json"] }
text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
......@@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
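These two new dependencies drive the OpenAPI refactor: `utoipa` derives the spec from annotated handlers and schemas, while `utoipa-swagger-ui` serves the interactive UI. A minimal wiring sketch for an axum app follows; the handler and type names are stand-ins, not the router's actual code:

```rust
// Sketch: derive an OpenAPI document with utoipa and mount Swagger UI in axum.
use axum::{routing::post, Json, Router};
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;

#[derive(serde::Deserialize, ToSchema)]
struct GenerateRequest {
    inputs: String,
}

#[derive(serde::Serialize, ToSchema)]
struct GenerateResponse {
    generated_text: String,
}

/// Stub standing in for the real `generate` handler.
#[utoipa::path(post, path = "/generate", request_body = GenerateRequest,
    responses((status = 200, description = "Generated Text", body = GenerateResponse)))]
async fn generate(Json(req): Json<GenerateRequest>) -> Json<GenerateResponse> {
    Json(GenerateResponse { generated_text: req.inputs })
}

#[derive(OpenApi)]
#[openapi(paths(generate), components(schemas(GenerateRequest, GenerateResponse)))]
struct ApiDoc;

fn app() -> Router {
    Router::new()
        // Serves the UI at /docs and the raw spec at /api-doc/openapi.json.
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/generate", post(generate))
}
```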
[package]
name = "text-generation-client"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
[dependencies]
......
......@@ -7,8 +7,8 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
StoppingCriteriaParameters,
Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
......
......@@ -127,7 +127,7 @@ impl Infer {
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob))
.map(|((id, logprob), text)| Token { id, text, logprob })
.collect();
}
// Push last token
......@@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
}
// Create last Token
let token = Token(
generation.token_id,
generation.token_text,
generation.token_logprob,
);
let token = Token {
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
};
if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message
......
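The refactor above replaces the old positional tuple struct (`Token(id, text, logprob)`) with named fields, which serialize to the object layout in the OpenAPI `Token` schema rather than to positional arrays. A sketch of the shape implied by these call sites; the derive list is an assumption:

```rust
// Sketch: Token with named fields, matching the new JSON fixtures
// ({"id": ..., "text": ..., "logprob": ...}) instead of positional arrays.
#[derive(serde::Serialize)]
pub struct Token {
    pub id: u32,
    pub text: String,
    pub logprob: f32,
}

fn main() {
    let token = Token { id: 17, text: ".".to_string(), logprob: -1.83 };
    println!("{}", serde_json::to_string(&token).unwrap());
    // prints: {"id":17,"text":".","logprob":-1.83}
}
```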